From ae9b4e5ea99673e4713deabd17d594cb5f9db4a0 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Wed, 16 Oct 2024 11:36:41 -0400 Subject: [PATCH 01/53] init --- .../ebpf/c/protocols/classification/defs.h | 2 +- .../classification/protocol-classification.h | 2 +- .../classification/shared-tracer-maps.h | 6 + pkg/network/ebpf/c/protocols/tls/tls.h | 542 ++++++++++++++++-- pkg/network/ebpf/c/tracer/stats.h | 2 + 5 files changed, 513 insertions(+), 41 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/classification/defs.h b/pkg/network/ebpf/c/protocols/classification/defs.h index 823112a4fb7e10..4bcd7593dbb601 100644 --- a/pkg/network/ebpf/c/protocols/classification/defs.h +++ b/pkg/network/ebpf/c/protocols/classification/defs.h @@ -103,7 +103,7 @@ typedef struct { // `protocol_stack_t` is embedded in the `conn_stats_t` type, which is used // across the whole NPM kernel code. If we added the 64-bit timestamp field // directly to `protocol_stack_t`, we would go from 4 bytes to 12 bytes, which -// bloats the eBPF stack size of some NPM probes. Using the wrapper type +// bloats the eBPF stack size of some NPM probes. Using the wrapper type // prevents that, because we pretty much only store the wrapper type in the // connection_protocol map, but elsewhere in the code we're still using // protocol_stack_t, so this is change is "transparent" to most of the code. diff --git a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h index a235138a5c7184..160c80d9e47219 100644 --- a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h +++ b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h @@ -169,7 +169,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct protocol_t app_layer_proto = get_protocol_from_stack(protocol_stack, LAYER_APPLICATION); - if ((app_layer_proto == PROTOCOL_UNKNOWN || app_layer_proto == PROTOCOL_POSTGRES) && is_tls(buffer, usm_ctx->buffer.size, skb_info.data_end)) { + if ((app_layer_proto == PROTOCOL_UNKNOWN || app_layer_proto == PROTOCOL_POSTGRES) && is_tls(skb)) { // TLS classification update_protocol_information(usm_ctx, protocol_stack, PROTOCOL_TLS); // The connection is TLS encrypted, thus we cannot classify the protocol diff --git a/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h b/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h index ebc1881c1f3c21..f5716e821acf00 100644 --- a/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h +++ b/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h @@ -9,6 +9,8 @@ // classification procedures on the same connection BPF_HASH_MAP(connection_protocol, conn_tuple_t, protocol_stack_wrapper_t, 0) +BPF_HASH_MAP(tls_expanded_tags, conn_tuple_t, tls_expanded_tags_t, 0) + static __always_inline protocol_stack_t* __get_protocol_stack(conn_tuple_t* tuple) { protocol_stack_wrapper_t *wrapper = bpf_map_lookup_elem(&connection_protocol, tuple); if (!wrapper) { @@ -17,6 +19,10 @@ static __always_inline protocol_stack_t* __get_protocol_stack(conn_tuple_t* tupl return &wrapper->stack; } +static __always_inline tls_expanded_tags_t* get_tls_expanded_tags(conn_tuple_t* tuple) { + return bpf_map_lookup_elem(&tls_expanded_tags, tuple); +} + static __always_inline protocol_stack_t* get_protocol_stack(conn_tuple_t *skb_tup) { conn_tuple_t normalized_tup = *skb_tup; normalize_tuple(&normalized_tup); diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index 0f7f2cdf9ee56b..938125bd12c2db 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -4,6 +4,13 @@ #include "ktypes.h" #include "bpf_builtins.h" +// #include // For Ethernet header structures +// #include // For TCP header structures +// #include // For IP header structures +#include "ip.h" + +#define ETH_HLEN 14 // Ethernet header length + #define SSL_VERSION20 0x0200 #define SSL_VERSION30 0x0300 #define TLS_VERSION10 0x0301 @@ -31,6 +38,21 @@ typedef struct { __u16 version; } __attribute__((packed)) tls_hello_message_t; +typedef struct { + __u16 offered_versions[6]; +} tls_client_tags_t; + +typedef struct { + __u16 version; + __u16 cipher_suite; + __u8 compression_method; +} tls_server_tags_t; + +typedef struct { + tls_client_tags_t client_tags; + tls_server_tags_t server_tags; +} tls_expanded_tags_t; + #define TLS_HANDSHAKE_CLIENT_HELLO 0x01 #define TLS_HANDSHAKE_SERVER_HELLO 0x02 // The size of the handshake type and the length. @@ -59,74 +81,516 @@ static __always_inline bool is_valid_tls_version(__u16 version) { // standard. // - the payload length + the size of the record header is less than the size // of the skb -static __always_inline bool is_valid_tls_app_data(tls_record_header_t *hdr, __u32 buf_size, __u32 skb_len) { - return sizeof(*hdr) + hdr->length <= skb_len; +static __always_inline bool is_valid_tls_app_data(tls_record_header_t *hdr, __u32 skb_len) { + return hdr->length + sizeof(tls_record_header_t) <= skb_len; } // is_tls_handshake checks if the given TLS message header is a valid TLS // handshake message. The message is considered valid if: // - The type matches CLIENT_HELLO or SERVER_HELLO // - The version is a known SSL/TLS version -static __always_inline bool is_tls_handshake(tls_record_header_t *hdr, const char *buf, __u32 buf_size) { - // Checking the buffer size contains at least the size of the tls record header and the tls hello message header. - if (sizeof(tls_record_header_t) + sizeof(tls_hello_message_t) > buf_size) { +// static __always_inline bool is_tls_handshake(tls_record_header_t *hdr, const char *buf, __u32 buf_size) { +// // Checking the buffer size contains at least the size of the tls record header and the tls hello message header. +// if (sizeof(tls_record_header_t) + sizeof(tls_hello_message_t) > buf_size) { +// return false; +// } +// // Checking the tls record header length is greater than the tls hello message header length. +// if (hdr->length < sizeof(tls_hello_message_t)) { +// return false; +// } + +// // Getting the tls hello message header. +// tls_hello_message_t msg = *(tls_hello_message_t *)(buf + sizeof(tls_record_header_t)); +// // If the message is not a CLIENT_HELLO or SERVER_HELLO, we don't attempt to classify. +// if (msg.handshake_type != TLS_HANDSHAKE_CLIENT_HELLO && msg.handshake_type != TLS_HANDSHAKE_SERVER_HELLO) { +// return false; +// } +// // Converting the fields to host byte order. +// __u32 length = msg.length[0] << 16 | msg.length[1] << 8 | msg.length[2]; +// // TLS handshake message length should be equal to the record header length minus the size of the hello message +// // header. +// if (length + TLS_HELLO_MESSAGE_HEADER_SIZE != hdr->length) { +// return false; +// } + +// msg.version = bpf_ntohs(msg.version); +// return is_valid_tls_version(msg.version) && msg.version >= hdr->version; +// } + +// // is_tls checks if the given buffer is a valid TLS record header. We are +// // currently checking for two types of record headers: +// // - TLS Handshake record headers +// // - TLS Application Data record headers +// static __always_inline bool is_tls(const char *buf, __u32 buf_size, __u32 skb_len) { +// if (buf_size < sizeof(tls_record_header_t)) { +// return false; +// } + +// // Copying struct to the stack, to avoid modifying the original buffer that will be used for other classifiers. +// tls_record_header_t tls_record_header = *(tls_record_header_t *)buf; +// // Converting the fields to host byte order. +// tls_record_header.version = bpf_ntohs(tls_record_header.version); +// tls_record_header.length = bpf_ntohs(tls_record_header.length); + +// // Checking the version in the record header. +// if (!is_valid_tls_version(tls_record_header.version)) { +// return false; +// } + +// // Checking the length in the record header is not greater than the maximum payload length. +// if (tls_record_header.length > TLS_MAX_PAYLOAD_LENGTH) { +// return false; +// } +// switch (tls_record_header.content_type) { +// case TLS_HANDSHAKE: +// return is_tls_handshake(&tls_record_header, buf, buf_size); +// case TLS_APPLICATION_DATA: +// return is_valid_tls_app_data(&tls_record_header, skb_len); +// } + +// return false; +// } + +static __always_inline int parse_ethernet(struct __sk_buff *skb, __u64 nh_off, __u16 *eth_proto) { + struct ethhdr eth; + + // Ensure there's enough data for the Ethernet header + if (nh_off + sizeof(struct ethhdr) > skb->len) + return -1; + + // Load the Ethernet header from the packet + if (bpf_skb_load_bytes(skb, nh_off, ð, sizeof(eth)) < 0) + return -1; + + // Extract the EtherType (protocol) + *eth_proto = bpf_ntohs(eth.h_proto); + + return 0; +} + +static __always_inline int parse_tcp(struct __sk_buff *skb, __u64 nh_off, __u64 *tcp_hdr_len) { + struct tcphdr tcp; + + // Ensure there's enough data for the TCP header (minimum 20 bytes) + if (nh_off + sizeof(struct tcphdr) > skb->len) + return -1; + + // Load the TCP header from the packet + if (bpf_skb_load_bytes(skb, nh_off, &tcp, sizeof(tcp)) < 0) + return -1; + + // Extract the Data Offset (Header Length) + // The data offset field specifies the size of the TCP header in 32-bit words + __u8 data_offset = tcp.doff; + + // Calculate the TCP header length in bytes + *tcp_hdr_len = (__u64)data_offset * 4; + + // Ensure that the computed TCP header length is valid + if (*tcp_hdr_len < sizeof(struct tcphdr)) + return -1; + if (nh_off + *tcp_hdr_len > skb->len) + return -1; + + return 0; +} + +static __always_inline int parse_ip(struct __sk_buff *skb, __u64 nh_off, __u8 *protocol, __u64 *ip_hdr_len) { + struct iphdr ip; + + // Ensure there's enough data for the IP header (minimum 20 bytes) + if (nh_off + sizeof(struct iphdr) > skb->len) + return -1; + + // Load IP header from the packet + if (bpf_skb_load_bytes(skb, nh_off, &ip, sizeof(ip)) < 0) + return -1; + + // Extract the Internet Header Length (IHL) + // The IHL field specifies the size of the IP header in 32-bit words + __u8 ihl = ip.ihl; + + // Calculate the IP header length in bytes + *ip_hdr_len = (__u64)ihl * 4; + + // Ensure that the computed IP header length is valid + if (*ip_hdr_len < sizeof(struct iphdr)) + return -1; + if (nh_off + *ip_hdr_len > skb->len) + return -1; + + // Extract the protocol field (e.g., TCP, UDP) + *protocol = ip.protocol; + + return 0; +} + +static __always_inline bool parse_supported_versions_extension(struct __sk_buff *skb, __u64 offset, __u16 extension_length) { + __u64 skb_len = skb->len; + + // Ensure we have at least 1 byte for the list length + if (offset + 1 > skb_len) { return false; } - // Checking the tls record header length is greater than the tls hello message header length. - if (hdr->length < sizeof(tls_hello_message_t)) { + + // Read Supported Versions Length (1 byte) + __u8 sv_list_length; + if (bpf_skb_load_bytes(skb, offset, &sv_list_length, sizeof(sv_list_length)) < 0) { return false; } + offset += 1; - // Getting the tls hello message header. - tls_hello_message_t msg = *(tls_hello_message_t *)(buf + sizeof(tls_record_header_t)); - // If the message is not a CLIENT_HELLO or SERVER_HELLO, we don't attempt to classify. - if (msg.handshake_type != TLS_HANDSHAKE_CLIENT_HELLO && msg.handshake_type != TLS_HANDSHAKE_SERVER_HELLO) { + // Ensure the list length is consistent with the extension length + if (sv_list_length + 1 > extension_length) { return false; } - // Converting the fields to host byte order. - __u32 length = msg.length[0] << 16 | msg.length[1] << 8 | msg.length[2]; - // TLS handshake message length should be equal to the record header length minus the size of the hello message - // header. - if (length + TLS_HELLO_MESSAGE_HEADER_SIZE != hdr->length) { + + // Ensure we don't read beyond the packet + if (offset + sv_list_length > skb_len) { return false; } - msg.version = bpf_ntohs(msg.version); - return is_valid_tls_version(msg.version) && msg.version >= hdr->version; + // Set an upper bound for the loop to satisfy the eBPF verifier + #define MAX_SUPPORTED_VERSIONS 8 + __u8 versions_parsed = 0; + + // Read the list of supported versions (2 bytes each) + for (__u8 i = 0; i + 1 < sv_list_length && versions_parsed < MAX_SUPPORTED_VERSIONS; i += 2, versions_parsed++) { + __u16 sv_version; + if (bpf_skb_load_bytes(skb, offset, &sv_version, sizeof(sv_version)) < 0) { + return false; + } + sv_version = bpf_ntohs(sv_version); + offset += 2; + + if (sv_version == TLS_VERSION13) { + log_debug("adamk supported version: TLS 1.3"); + } + + // TODO: add supported version to the map + + } + log_debug("adamk supported versions parsed: %d", versions_parsed); + return true; +} + +static __always_inline bool parse_client_hello_extensions(struct __sk_buff *skb, __u64 offset, __u16 extensions_length) { + __u64 skb_len = skb->len; + __u64 extensions_end = offset + extensions_length; + + // Set an upper bound for the loop to satisfy the eBPF verifier + #define MAX_EXTENSIONS 16 + __u8 extensions_parsed = 0; + + while (offset + 4 <= extensions_end && extensions_parsed < MAX_EXTENSIONS) { + // Read Extension Type (2 bytes) + __u16 extension_type; + if (bpf_skb_load_bytes(skb, offset, &extension_type, sizeof(extension_type)) < 0) { + return false; + } + extension_type = bpf_ntohs(extension_type); + offset += 2; + + // Read Extension Length (2 bytes) + __u16 extension_length; + if (bpf_skb_load_bytes(skb, offset, &extension_length, sizeof(extension_length)) < 0) { + return false; + } + extension_length = bpf_ntohs(extension_length); + offset += 2; + + // Ensure we don't read beyond the packet + if (offset + extension_length > skb_len || offset + extension_length > extensions_end) { + return false; + } + + // Check for supported_versions extension (type 0x002B) + if (extension_type == 0x002B) { + if (!parse_supported_versions_extension(skb, offset, extension_length)) { + return false; + } + } + + // Skip to the next extension + offset += extension_length; + extensions_parsed++; + } + + return true; } -// is_tls checks if the given buffer is a valid TLS record header. We are -// currently checking for two types of record headers: -// - TLS Handshake record headers -// - TLS Application Data record headers -static __always_inline bool is_tls(const char *buf, __u32 buf_size, __u32 skb_len) { - if (buf_size < sizeof(tls_record_header_t)) { +static __always_inline bool parse_client_hello(tls_record_header_t *hdr, struct __sk_buff *skb, __u64 offset) { + __u32 skb_len = skb->len; + + // Move offset past handshake type (1 byte) + offset += 1; + + // Read handshake length (3 bytes) + __u8 handshake_length_bytes[3]; + if (bpf_skb_load_bytes(skb, offset, handshake_length_bytes, 3) < 0) + return false; + __u32 handshake_length = (handshake_length_bytes[0] << 16) | + (handshake_length_bytes[1] << 8) | + handshake_length_bytes[2]; + offset += 3; + + // Ensure we don't read beyond the packet + if (offset + handshake_length > skb_len) + return false; + + // Read client version (2 bytes) + __u16 client_version; + if (bpf_skb_load_bytes(skb, offset, &client_version, sizeof(client_version)) < 0) + return false; + client_version = bpf_ntohs(client_version); + log_debug("adamk client version: %d", client_version); + offset += 2; + + // Validate client version + if (!is_valid_tls_version(client_version)) + return false; + + // Skip Random (32 bytes) + offset += 32; + + // Read Session ID Length (1 byte) + __u8 session_id_len; + if (bpf_skb_load_bytes(skb, offset, &session_id_len, sizeof(session_id_len)) < 0) + return false; + offset += 1; + + // Skip Session ID + offset += session_id_len; + + // Ensure we don't read beyond the packet + if (offset + 2 > skb_len) + return false; + + // Read Cipher Suites Length (2 bytes) + __u16 cipher_suites_length; + if (bpf_skb_load_bytes(skb, offset, &cipher_suites_length, sizeof(cipher_suites_length)) < 0) + return false; + cipher_suites_length = bpf_ntohs(cipher_suites_length); + log_debug("adamk client cipher_suites_length: %d", cipher_suites_length); + offset += 2; + + // Ensure we don't read beyond the packet + if (offset + cipher_suites_length > skb_len) + return false; + + // Skip Cipher Suites + offset += cipher_suites_length; + + // Read Compression Methods Length (1 byte) + if (offset + 1 > skb_len) + return false; + __u8 compression_methods_length; + if (bpf_skb_load_bytes(skb, offset, &compression_methods_length, sizeof(compression_methods_length)) < 0) + return false; + offset += 1; + + // Skip Compression Methods + offset += compression_methods_length; + + // Read Extensions Length (2 bytes) + if (offset + 2 > skb_len) { return false; } + __u16 extensions_length; + if (bpf_skb_load_bytes(skb, offset, &extensions_length, sizeof(extensions_length)) < 0) { + return false; + } + extensions_length = bpf_ntohs(extensions_length); + offset += 2; - // Copying struct to the stack, to avoid modifying the original buffer that will be used for other classifiers. - tls_record_header_t tls_record_header = *(tls_record_header_t *)buf; - // Converting the fields to host byte order. - tls_record_header.version = bpf_ntohs(tls_record_header.version); - tls_record_header.length = bpf_ntohs(tls_record_header.length); + // Ensure we don't read beyond the packet + if (offset + extensions_length > skb_len) { + return false; + } - // Checking the version in the record header. - if (!is_valid_tls_version(tls_record_header.version)) { + // Parse Extensions + if (!parse_client_hello_extensions(skb, offset, extensions_length)) { return false; } - // Checking the length in the record header is not greater than the maximum payload length. - if (tls_record_header.length > TLS_MAX_PAYLOAD_LENGTH) { + // At this point, we've successfully parsed the ClientHello message + return true; +} + +static __always_inline bool parse_server_hello(tls_record_header_t *hdr, struct __sk_buff *skb, __u64 offset) { + __u32 skb_len = skb->len; + + // Move offset past handshake type (1 byte) + offset += 1; + + // Read handshake length (3 bytes) + __u8 handshake_length_bytes[3]; + if (bpf_skb_load_bytes(skb, offset, handshake_length_bytes, 3) < 0) + return false; + __u32 handshake_length = (handshake_length_bytes[0] << 16) | + (handshake_length_bytes[1] << 8) | + handshake_length_bytes[2]; + offset += 3; + + // Ensure we don't read beyond the packet + if (offset + handshake_length > skb_len) + return false; + + // Read server version (2 bytes) + __u16 server_version; + if (bpf_skb_load_bytes(skb, offset, &server_version, sizeof(server_version)) < 0) + return false; + server_version = bpf_ntohs(server_version); + log_debug("adamk server version: %d", server_version); + offset += 2; + + // Validate server version + if (!is_valid_tls_version(server_version)) + return false; + + // Skip Random (32 bytes) + offset += 32; + + // Read Session ID Length (1 byte) + __u8 session_id_len; + if (bpf_skb_load_bytes(skb, offset, &session_id_len, sizeof(session_id_len)) < 0) + return false; + offset += 1; + + // Skip Session ID + offset += session_id_len; + + // Ensure we don't read beyond the packet + if (offset + 2 > skb_len) + return false; + + // Read Cipher Suite (2 bytes) + __u16 cipher_suite; + if (bpf_skb_load_bytes(skb, offset, &cipher_suite, sizeof(cipher_suite)) < 0) + return false; + cipher_suite = bpf_ntohs(cipher_suite); + log_debug("adamk server cipher_suite: %d", cipher_suite); + offset += 2; + + // You can store or process the cipher suite as needed + + // Read Compression Method (1 byte) + if (offset + 1 > skb_len) return false; + __u8 compression_method; + if (bpf_skb_load_bytes(skb, offset, &compression_method, sizeof(compression_method)) < 0) + return false; + offset += 1; + + // Read Extensions Length (2 bytes) if present + if (offset + 2 <= skb_len) { + __u16 extensions_length; + if (bpf_skb_load_bytes(skb, offset, &extensions_length, sizeof(extensions_length)) < 0) + return false; + extensions_length = bpf_ntohs(extensions_length); + offset += 2; + + // Ensure we don't read beyond the packet + if (offset + extensions_length > skb_len) + return false; + + // Process extensions if needed } - switch (tls_record_header.content_type) { - case TLS_HANDSHAKE: - return is_tls_handshake(&tls_record_header, buf, buf_size); - case TLS_APPLICATION_DATA: - return is_valid_tls_app_data(&tls_record_header, buf_size, skb_len); + + // At this point, we've successfully parsed the ServerHello message + return true; +} + +static __always_inline bool is_tls_handshake(tls_record_header_t *hdr, struct __sk_buff *skb, __u64 offset) { + // Read handshake type + __u8 handshake_type; + if (bpf_skb_load_bytes(skb, offset, &handshake_type, sizeof(handshake_type)) < 0) + return false; + + // Only proceed if it's a ClientHello or ServerHello + if (handshake_type == TLS_HANDSHAKE_CLIENT_HELLO) { + log_debug("adamk inspecting client hello"); + return parse_client_hello(hdr, skb, offset); + } else if (handshake_type == TLS_HANDSHAKE_SERVER_HELLO) { + log_debug("adamk inspecting server hello"); + return parse_server_hello(hdr, skb, offset); + } else { + return false; } +} - return false; +static __always_inline bool is_tls(struct __sk_buff *skb) { + __u64 nh_off = 0; + __u32 skb_len = skb->len; + __u16 eth_proto = 0; + + // Parse Ethernet header + if (parse_ethernet(skb, nh_off, ð_proto) < 0) + return false; + nh_off += ETH_HLEN; + + // Parse IP header + if (eth_proto == ETH_P_IP) { + __u8 ip_proto = 0; + __u64 ip_hdr_len = 0; + if (parse_ip(skb, nh_off, &ip_proto, &ip_hdr_len) < 0) + return false; + nh_off += ip_hdr_len; + + if (ip_proto != IPPROTO_TCP) + return false; + } else if (eth_proto == ETH_P_IPV6) { + // Parse IPv6 header (left as an exercise) + return false; + } else { + return false; + } + + // Parse TCP header + __u64 tcp_hdr_len = 0; + if (parse_tcp(skb, nh_off, &tcp_hdr_len) < 0) + return false; + nh_off += tcp_hdr_len; + + // Ensure there's enough space for TLS record header + if (nh_off + sizeof(tls_record_header_t) > skb_len) + return false; + + // Read TLS record header + tls_record_header_t tls_hdr; + if (bpf_skb_load_bytes(skb, nh_off, &tls_hdr, sizeof(tls_hdr)) < 0) + return false; + + // Convert fields to host byte order + tls_hdr.version = bpf_ntohs(tls_hdr.version); + tls_hdr.length = bpf_ntohs(tls_hdr.length); + + // Validate version and length + if (!is_valid_tls_version(tls_hdr.version)) + return false; + if (tls_hdr.length > TLS_MAX_PAYLOAD_LENGTH) + return false; + + // Move offset to the start of TLS handshake message + nh_off += sizeof(tls_record_header_t); + + // Ensure we don't read beyond the packet + if (nh_off + tls_hdr.length > skb_len) + return false; + + // Handle based on content type + switch (tls_hdr.content_type) { + case TLS_HANDSHAKE: { + // return is_tls_handshake(&tls_hdr, skb, nh_off); + bool handshake = is_tls_handshake(&tls_hdr, skb, nh_off); + log_debug("adamk is_tls_handshake: %d", handshake); + return handshake; + } + case TLS_APPLICATION_DATA: + return is_valid_tls_app_data(&tls_hdr, skb_len); + default: + return false; + } } #endif diff --git a/pkg/network/ebpf/c/tracer/stats.h b/pkg/network/ebpf/c/tracer/stats.h index cb768fa71e3e4e..c353529d566941 100644 --- a/pkg/network/ebpf/c/tracer/stats.h +++ b/pkg/network/ebpf/c/tracer/stats.h @@ -120,6 +120,8 @@ static __always_inline void update_protocol_classification_information(conn_tupl set_protocol_flag(protocol_stack, FLAG_NPM_ENABLED); mark_protocol_direction(t, &conn_tuple_copy, protocol_stack); merge_protocol_stacks(&stats->protocol_stack, protocol_stack); + // lookup from new map and add info to the connection + tls_expanded_tags_t tls_tags = get_tls_expanded_tags(&conn_tuple_copy); } static __always_inline void determine_connection_direction(conn_tuple_t *t, conn_stats_ts_t *conn_stats) { From a2ec9584bca30ecfe9b78983ae7b3c7328b01a8a Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Thu, 17 Oct 2024 17:47:33 -0400 Subject: [PATCH 02/53] rework collection logic to use existing parsing --- .../classification/protocol-classification.h | 2 +- .../classification/shared-tracer-maps.h | 7 +- pkg/network/ebpf/c/protocols/tls/tls.h | 183 +------------ pkg/network/ebpf/c/tracer/stats.h | 2 +- pkg/network/ebpf/probes/probes.go | 1 + pkg/network/tracer/connection/ebpf_tracer.go | 1 + pkg/network/tracer/tracer_test.go | 247 ++++++++++++++++++ 7 files changed, 263 insertions(+), 180 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h index 160c80d9e47219..a2641ff91cc127 100644 --- a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h +++ b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h @@ -169,7 +169,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct protocol_t app_layer_proto = get_protocol_from_stack(protocol_stack, LAYER_APPLICATION); - if ((app_layer_proto == PROTOCOL_UNKNOWN || app_layer_proto == PROTOCOL_POSTGRES) && is_tls(skb)) { + if ((app_layer_proto == PROTOCOL_UNKNOWN || app_layer_proto == PROTOCOL_POSTGRES) && is_tls(skb, skb_info.data_off)) { // TLS classification update_protocol_information(usm_ctx, protocol_stack, PROTOCOL_TLS); // The connection is TLS encrypted, thus we cannot classify the protocol diff --git a/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h b/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h index f5716e821acf00..2f6da79cbbcae1 100644 --- a/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h +++ b/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h @@ -4,12 +4,13 @@ #include "map-defs.h" #include "port_range.h" #include "protocols/classification/stack-helpers.h" +#include "protocols/tls/tls.h" // Maps a connection tuple to its classified protocol. Used to reduce redundant // classification procedures on the same connection BPF_HASH_MAP(connection_protocol, conn_tuple_t, protocol_stack_wrapper_t, 0) -BPF_HASH_MAP(tls_expanded_tags, conn_tuple_t, tls_expanded_tags_t, 0) +BPF_HASH_MAP(tls_enhanced_tags, conn_tuple_t, tls_enhanced_tags_t, 0) static __always_inline protocol_stack_t* __get_protocol_stack(conn_tuple_t* tuple) { protocol_stack_wrapper_t *wrapper = bpf_map_lookup_elem(&connection_protocol, tuple); @@ -19,8 +20,8 @@ static __always_inline protocol_stack_t* __get_protocol_stack(conn_tuple_t* tupl return &wrapper->stack; } -static __always_inline tls_expanded_tags_t* get_tls_expanded_tags(conn_tuple_t* tuple) { - return bpf_map_lookup_elem(&tls_expanded_tags, tuple); +static __always_inline tls_enhanced_tags_t* get_tls_enhanced_tags(conn_tuple_t* tuple) { + return bpf_map_lookup_elem(&tls_enhanced_tags, tuple); } static __always_inline protocol_stack_t* get_protocol_stack(conn_tuple_t *skb_tup) { diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index 938125bd12c2db..67aa6d7494f2fe 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -4,9 +4,6 @@ #include "ktypes.h" #include "bpf_builtins.h" -// #include // For Ethernet header structures -// #include // For TCP header structures -// #include // For IP header structures #include "ip.h" #define ETH_HLEN 14 // Ethernet header length @@ -51,7 +48,7 @@ typedef struct { typedef struct { tls_client_tags_t client_tags; tls_server_tags_t server_tags; -} tls_expanded_tags_t; +} tls_enhanced_tags_t; #define TLS_HANDSHAKE_CLIENT_HELLO 0x01 #define TLS_HANDSHAKE_SERVER_HELLO 0x02 @@ -85,146 +82,6 @@ static __always_inline bool is_valid_tls_app_data(tls_record_header_t *hdr, __u3 return hdr->length + sizeof(tls_record_header_t) <= skb_len; } -// is_tls_handshake checks if the given TLS message header is a valid TLS -// handshake message. The message is considered valid if: -// - The type matches CLIENT_HELLO or SERVER_HELLO -// - The version is a known SSL/TLS version -// static __always_inline bool is_tls_handshake(tls_record_header_t *hdr, const char *buf, __u32 buf_size) { -// // Checking the buffer size contains at least the size of the tls record header and the tls hello message header. -// if (sizeof(tls_record_header_t) + sizeof(tls_hello_message_t) > buf_size) { -// return false; -// } -// // Checking the tls record header length is greater than the tls hello message header length. -// if (hdr->length < sizeof(tls_hello_message_t)) { -// return false; -// } - -// // Getting the tls hello message header. -// tls_hello_message_t msg = *(tls_hello_message_t *)(buf + sizeof(tls_record_header_t)); -// // If the message is not a CLIENT_HELLO or SERVER_HELLO, we don't attempt to classify. -// if (msg.handshake_type != TLS_HANDSHAKE_CLIENT_HELLO && msg.handshake_type != TLS_HANDSHAKE_SERVER_HELLO) { -// return false; -// } -// // Converting the fields to host byte order. -// __u32 length = msg.length[0] << 16 | msg.length[1] << 8 | msg.length[2]; -// // TLS handshake message length should be equal to the record header length minus the size of the hello message -// // header. -// if (length + TLS_HELLO_MESSAGE_HEADER_SIZE != hdr->length) { -// return false; -// } - -// msg.version = bpf_ntohs(msg.version); -// return is_valid_tls_version(msg.version) && msg.version >= hdr->version; -// } - -// // is_tls checks if the given buffer is a valid TLS record header. We are -// // currently checking for two types of record headers: -// // - TLS Handshake record headers -// // - TLS Application Data record headers -// static __always_inline bool is_tls(const char *buf, __u32 buf_size, __u32 skb_len) { -// if (buf_size < sizeof(tls_record_header_t)) { -// return false; -// } - -// // Copying struct to the stack, to avoid modifying the original buffer that will be used for other classifiers. -// tls_record_header_t tls_record_header = *(tls_record_header_t *)buf; -// // Converting the fields to host byte order. -// tls_record_header.version = bpf_ntohs(tls_record_header.version); -// tls_record_header.length = bpf_ntohs(tls_record_header.length); - -// // Checking the version in the record header. -// if (!is_valid_tls_version(tls_record_header.version)) { -// return false; -// } - -// // Checking the length in the record header is not greater than the maximum payload length. -// if (tls_record_header.length > TLS_MAX_PAYLOAD_LENGTH) { -// return false; -// } -// switch (tls_record_header.content_type) { -// case TLS_HANDSHAKE: -// return is_tls_handshake(&tls_record_header, buf, buf_size); -// case TLS_APPLICATION_DATA: -// return is_valid_tls_app_data(&tls_record_header, skb_len); -// } - -// return false; -// } - -static __always_inline int parse_ethernet(struct __sk_buff *skb, __u64 nh_off, __u16 *eth_proto) { - struct ethhdr eth; - - // Ensure there's enough data for the Ethernet header - if (nh_off + sizeof(struct ethhdr) > skb->len) - return -1; - - // Load the Ethernet header from the packet - if (bpf_skb_load_bytes(skb, nh_off, ð, sizeof(eth)) < 0) - return -1; - - // Extract the EtherType (protocol) - *eth_proto = bpf_ntohs(eth.h_proto); - - return 0; -} - -static __always_inline int parse_tcp(struct __sk_buff *skb, __u64 nh_off, __u64 *tcp_hdr_len) { - struct tcphdr tcp; - - // Ensure there's enough data for the TCP header (minimum 20 bytes) - if (nh_off + sizeof(struct tcphdr) > skb->len) - return -1; - - // Load the TCP header from the packet - if (bpf_skb_load_bytes(skb, nh_off, &tcp, sizeof(tcp)) < 0) - return -1; - - // Extract the Data Offset (Header Length) - // The data offset field specifies the size of the TCP header in 32-bit words - __u8 data_offset = tcp.doff; - - // Calculate the TCP header length in bytes - *tcp_hdr_len = (__u64)data_offset * 4; - - // Ensure that the computed TCP header length is valid - if (*tcp_hdr_len < sizeof(struct tcphdr)) - return -1; - if (nh_off + *tcp_hdr_len > skb->len) - return -1; - - return 0; -} - -static __always_inline int parse_ip(struct __sk_buff *skb, __u64 nh_off, __u8 *protocol, __u64 *ip_hdr_len) { - struct iphdr ip; - - // Ensure there's enough data for the IP header (minimum 20 bytes) - if (nh_off + sizeof(struct iphdr) > skb->len) - return -1; - - // Load IP header from the packet - if (bpf_skb_load_bytes(skb, nh_off, &ip, sizeof(ip)) < 0) - return -1; - - // Extract the Internet Header Length (IHL) - // The IHL field specifies the size of the IP header in 32-bit words - __u8 ihl = ip.ihl; - - // Calculate the IP header length in bytes - *ip_hdr_len = (__u64)ihl * 4; - - // Ensure that the computed IP header length is valid - if (*ip_hdr_len < sizeof(struct iphdr)) - return -1; - if (nh_off + *ip_hdr_len > skb->len) - return -1; - - // Extract the protocol field (e.g., TCP, UDP) - *protocol = ip.protocol; - - return 0; -} - static __always_inline bool parse_supported_versions_extension(struct __sk_buff *skb, __u64 offset, __u16 extension_length) { __u64 skb_len = skb->len; @@ -519,38 +376,8 @@ static __always_inline bool is_tls_handshake(tls_record_header_t *hdr, struct __ } } -static __always_inline bool is_tls(struct __sk_buff *skb) { - __u64 nh_off = 0; +static __always_inline bool is_tls(struct __sk_buff *skb, __u64 nh_off) { __u32 skb_len = skb->len; - __u16 eth_proto = 0; - - // Parse Ethernet header - if (parse_ethernet(skb, nh_off, ð_proto) < 0) - return false; - nh_off += ETH_HLEN; - - // Parse IP header - if (eth_proto == ETH_P_IP) { - __u8 ip_proto = 0; - __u64 ip_hdr_len = 0; - if (parse_ip(skb, nh_off, &ip_proto, &ip_hdr_len) < 0) - return false; - nh_off += ip_hdr_len; - - if (ip_proto != IPPROTO_TCP) - return false; - } else if (eth_proto == ETH_P_IPV6) { - // Parse IPv6 header (left as an exercise) - return false; - } else { - return false; - } - - // Parse TCP header - __u64 tcp_hdr_len = 0; - if (parse_tcp(skb, nh_off, &tcp_hdr_len) < 0) - return false; - nh_off += tcp_hdr_len; // Ensure there's enough space for TLS record header if (nh_off + sizeof(tls_record_header_t) > skb_len) @@ -565,6 +392,12 @@ static __always_inline bool is_tls(struct __sk_buff *skb) { tls_hdr.version = bpf_ntohs(tls_hdr.version); tls_hdr.length = bpf_ntohs(tls_hdr.length); + // Validate version and length + if (!is_valid_tls_version(tls_hdr.version)) + return false; + if (tls_hdr.length > TLS_MAX_PAYLOAD_LENGTH) + return false; + // Validate version and length if (!is_valid_tls_version(tls_hdr.version)) return false; diff --git a/pkg/network/ebpf/c/tracer/stats.h b/pkg/network/ebpf/c/tracer/stats.h index c353529d566941..03ab116b8972e8 100644 --- a/pkg/network/ebpf/c/tracer/stats.h +++ b/pkg/network/ebpf/c/tracer/stats.h @@ -121,7 +121,7 @@ static __always_inline void update_protocol_classification_information(conn_tupl mark_protocol_direction(t, &conn_tuple_copy, protocol_stack); merge_protocol_stacks(&stats->protocol_stack, protocol_stack); // lookup from new map and add info to the connection - tls_expanded_tags_t tls_tags = get_tls_expanded_tags(&conn_tuple_copy); + // tls_enhanced_tags_t *tls_tags = get_tls_enhanced_tags(&conn_tuple_copy); } static __always_inline void determine_connection_direction(conn_tuple_t *t, conn_stats_ts_t *conn_stats) { diff --git a/pkg/network/ebpf/probes/probes.go b/pkg/network/ebpf/probes/probes.go index 8d8cfa52475420..4b8eaf28e01593 100644 --- a/pkg/network/ebpf/probes/probes.go +++ b/pkg/network/ebpf/probes/probes.go @@ -236,6 +236,7 @@ const ( ConnectionProtocolMap BPFMapName = "connection_protocol" // ConnectionTupleToSocketSKBConnMap is the map storing the connection tuple to socket skb conn tuple ConnectionTupleToSocketSKBConnMap BPFMapName = "conn_tuple_to_socket_skb_conn_tuple" + EnhancedTLSTagsMap BPFMapName = "tls_enhanced_tags" // ClassificationProgsMap is the map storing the programs to run on classification events ClassificationProgsMap BPFMapName = "classification_progs" // TCPCloseProgsMap is the map storing the programs to run on TCP close events diff --git a/pkg/network/tracer/connection/ebpf_tracer.go b/pkg/network/tracer/connection/ebpf_tracer.go index 4792e6f2aca5ae..97a86cdb5779bb 100644 --- a/pkg/network/tracer/connection/ebpf_tracer.go +++ b/pkg/network/tracer/connection/ebpf_tracer.go @@ -183,6 +183,7 @@ func newEbpfTracer(config *config.Config, _ telemetryComponent.Component) (Trace probes.PortBindingsMap: {MaxEntries: config.MaxTrackedConnections, EditorFlag: manager.EditMaxEntries}, probes.UDPPortBindingsMap: {MaxEntries: config.MaxTrackedConnections, EditorFlag: manager.EditMaxEntries}, probes.ConnectionProtocolMap: {MaxEntries: config.MaxTrackedConnections, EditorFlag: manager.EditMaxEntries}, + probes.EnhancedTLSTagsMap: {MaxEntries: config.MaxTrackedConnections, EditorFlag: manager.EditMaxEntries}, probes.ConnectionTupleToSocketSKBConnMap: {MaxEntries: config.MaxTrackedConnections, EditorFlag: manager.EditMaxEntries}, probes.TCPOngoingConnectPid: {MaxEntries: config.MaxTrackedConnections, EditorFlag: manager.EditMaxEntries}, }, diff --git a/pkg/network/tracer/tracer_test.go b/pkg/network/tracer/tracer_test.go index d552f29b95bec1..cf83b1ac75fa21 100644 --- a/pkg/network/tracer/tracer_test.go +++ b/pkg/network/tracer/tracer_test.go @@ -11,6 +11,8 @@ import ( "bufio" "bytes" "context" + "crypto/tls" + "crypto/x509" "encoding/json" "errors" "fmt" @@ -1090,6 +1092,7 @@ func (s *TracerSuite) TestTCPEstablished() { require.True(t, ok) assert.Equal(t, uint32(0), conn.Last.TCPEstablished) assert.Equal(t, uint32(1), conn.Last.TCPClosed) + assert.Equal(t, uint32(1), conn.Tags[0]) } func (s *TracerSuite) TestTCPEstablishedPreExistingConn() { @@ -1407,3 +1410,247 @@ func findFailedConnectionByRemoteAddr(remoteAddr string, conns *network.Connecti } return network.FirstConnection(conns, failureFilter) } + +var serverCertPEM = []byte(`-----BEGIN CERTIFICATE----- +... (your server certificate here) ... +-----END CERTIFICATE-----`) + +var serverKeyPEM = []byte(`-----BEGIN RSA PRIVATE KEY----- +... (your server private key here) ... +-----END RSA PRIVATE KEY-----`) + +var clientCertPEM = []byte(`-----BEGIN CERTIFICATE----- +... (your client certificate here) ... +-----END CERTIFICATE-----`) + +var clientKeyPEM = []byte(`-----BEGIN RSA PRIVATE KEY----- +... (your client private key here) ... +-----END RSA PRIVATE KEY-----`) + +var caCertPEM = []byte(`-----BEGIN CERTIFICATE----- +... (your CA certificate here) ... +-----END CERTIFICATE-----`) + +func (s *TracerSuite) TestTLSConnection() { + t := s.T() + t.Skip() + + // Setup tracer with default configuration + cfg := testConfig() + tr := setupTracer(t, cfg) + + // Start a TLS server + cert, err := tls.X509KeyPair(serverCertPEM, serverKeyPEM) + require.NoError(t, err, "failed to load server key pair") + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + ln, err := tls.Listen("tcp", "localhost:0", tlsConfig) + require.NoError(t, err, "failed to start TLS server") + defer ln.Close() + + serverAddr := ln.Addr().String() + + // Start server goroutine + done := make(chan struct{}) + go func() { + defer close(done) + conn, err := ln.Accept() + if err != nil { + t.Log("Server accept error:", err) + return + } + defer conn.Close() + io.Copy(io.Discard, conn) + }() + + // Create a TLS client connection + clientTLSConfig := &tls.Config{ + InsecureSkipVerify: true, // For testing purposes only + } + conn, err := tls.Dial("tcp", serverAddr, clientTLSConfig) + require.NoError(t, err, "failed to connect to TLS server") + defer conn.Close() + + // Send some data + _, err = conn.Write([]byte("hello")) + require.NoError(t, err, "failed to write to TLS server") + + // Wait for the tracer to pick up the connection + var tracedConn *network.ConnectionStats + require.Eventually(t, func() bool { + conns := getConnections(t, tr) + laddr, raddr := conn.LocalAddr(), conn.RemoteAddr() + var ok bool + tracedConn, ok = findConnection(laddr, raddr, conns) + return ok + }, 3*time.Second, 100*time.Millisecond, "could not find TLS connection") + + require.NotNil(t, tracedConn, "traced connection is nil") + + // Check that the TLS tags are present + tags := tracedConn.Tags + assert.Contains(t, tags, "tls_version", "tls_version tag missing") + assert.Contains(t, tags, "tls_cipher_suite", "tls_cipher_suite tag missing") + + // Signal the server to finish + conn.Close() + <-done +} + +func (s *TracerSuite) TestTLSConnectionWithClientCert() { + t := s.T() + t.Skip() + + // Setup tracer with default configuration + cfg := testConfig() + tr := setupTracer(t, cfg) + + // Load server certificate and key + serverCert, err := tls.X509KeyPair(serverCertPEM, serverKeyPEM) + require.NoError(t, err, "failed to load server key pair") + + // Load client certificate and key + clientCert, err := tls.X509KeyPair(clientCertPEM, clientKeyPEM) + require.NoError(t, err, "failed to load client key pair") + + // Create a certificate pool with the server's CA + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCertPEM) + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + ClientCAs: caCertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + } + + ln, err := tls.Listen("tcp", "localhost:0", tlsConfig) + require.NoError(t, err, "failed to start TLS server") + defer ln.Close() + + serverAddr := ln.Addr().String() + + // Start server goroutine + done := make(chan struct{}) + go func() { + defer close(done) + conn, err := ln.Accept() + if err != nil { + t.Log("Server accept error:", err) + return + } + defer conn.Close() + io.Copy(io.Discard, conn) + }() + + // Create a TLS client connection + clientTLSConfig := &tls.Config{ + Certificates: []tls.Certificate{clientCert}, + InsecureSkipVerify: true, // For testing purposes only + } + conn, err := tls.Dial("tcp", serverAddr, clientTLSConfig) + require.NoError(t, err, "failed to connect to TLS server") + defer conn.Close() + + // Send some data + _, err = conn.Write([]byte("hello")) + require.NoError(t, err, "failed to write to TLS server") + + // Wait for the tracer to pick up the connection + var tracedConn *network.ConnectionStats + require.Eventually(t, func() bool { + conns := getConnections(t, tr) + laddr, raddr := conn.LocalAddr(), conn.RemoteAddr() + var ok bool + tracedConn, ok = findConnection(laddr, raddr, conns) + return ok + }, 3*time.Second, 100*time.Millisecond, "could not find TLS connection") + + require.NotNil(t, tracedConn, "traced connection is nil") + + // Check that the TLS tags are present + tags := tracedConn.Tags + assert.Contains(t, tags, "tls_version", "tls_version tag missing") + assert.Contains(t, tags, "tls_cipher_suite", "tls_cipher_suite tag missing") + + // Signal the server to finish + conn.Close() + <-done +} + +func (s *TracerSuite) TestTLSConnectionTLS12() { + t := s.T() + t.Skip() + + // Setup tracer with default configuration + cfg := testConfig() + tr := setupTracer(t, cfg) + + // Start a TLS server with TLS 1.2 + cert, err := tls.X509KeyPair(serverCertPEM, serverKeyPEM) + require.NoError(t, err, "failed to load server key pair") + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, + MaxVersion: tls.VersionTLS12, + } + + ln, err := tls.Listen("tcp", "localhost:0", tlsConfig) + require.NoError(t, err, "failed to start TLS server") + defer ln.Close() + + serverAddr := ln.Addr().String() + + // Start server goroutine + done := make(chan struct{}) + go func() { + defer close(done) + conn, err := ln.Accept() + if err != nil { + t.Log("Server accept error:", err) + return + } + defer conn.Close() + io.Copy(io.Discard, conn) + }() + + // Create a TLS client connection with TLS 1.2 + clientTLSConfig := &tls.Config{ + InsecureSkipVerify: true, // For testing purposes only + MinVersion: tls.VersionTLS12, + MaxVersion: tls.VersionTLS12, + } + conn, err := tls.Dial("tcp", serverAddr, clientTLSConfig) + require.NoError(t, err, "failed to connect to TLS server") + defer conn.Close() + + // Send some data + _, err = conn.Write([]byte("hello")) + require.NoError(t, err, "failed to write to TLS server") + + // Wait for the tracer to pick up the connection + var tracedConn *network.ConnectionStats + require.Eventually(t, func() bool { + conns := getConnections(t, tr) + laddr, raddr := conn.LocalAddr(), conn.RemoteAddr() + var ok bool + tracedConn, ok = findConnection(laddr, raddr, conns) + return ok + }, 3*time.Second, 100*time.Millisecond, "could not find TLS connection") + + require.NotNil(t, tracedConn, "traced connection is nil") + + // Check that the TLS tags are present and correct + tags := tracedConn.Tags + assert.Contains(t, tags, "tls_version", "tls_version tag missing") + //assert.Equal(t, "TLSv1.2", tags["tls_version"], "expected TLS version 1.2") // TDOD: Fix this + + assert.Contains(t, tags, "tls_cipher_suite", "tls_cipher_suite tag missing") + + // Signal the server to finish + conn.Close() + <-done +} From 2768c3f4bb703057273b25322b1e2ca1bc83824f Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Thu, 17 Oct 2024 21:42:51 -0400 Subject: [PATCH 03/53] fix usm only map init --- pkg/network/usm/ebpf_main.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pkg/network/usm/ebpf_main.go b/pkg/network/usm/ebpf_main.go index cc1bec4f89bf40..6372d6269a3630 100644 --- a/pkg/network/usm/ebpf_main.go +++ b/pkg/network/usm/ebpf_main.go @@ -401,6 +401,10 @@ func (e *ebpfProgram) init(buf bytecode.AssetReader, options manager.Options) er MaxEntries: e.cfg.MaxTrackedConnections, EditorFlag: manager.EditMaxEntries, }, + probes.EnhancedTLSTagsMap: { + MaxEntries: e.cfg.MaxTrackedConnections, + EditorFlag: manager.EditMaxEntries, + }, tupleByPidFDMap: { MaxEntries: e.cfg.MaxTrackedConnections, EditorFlag: manager.EditMaxEntries, From 8b3f4375e42c168bbe329bc600db2908e55b3f41 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Fri, 18 Oct 2024 16:02:53 -0400 Subject: [PATCH 04/53] rework tests and collection --- pkg/network/ebpf/c/protocols/tls/tls.h | 146 +++++---- pkg/network/tracer/testutil/tcp.go | 13 + pkg/network/tracer/tracer_test.go | 307 +++++------------- .../usm/tests/tracer_usm_linux_test.go | 16 +- 4 files changed, 176 insertions(+), 306 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index 67aa6d7494f2fe..4d2622c91c9bea 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -67,7 +67,6 @@ static __always_inline bool is_valid_tls_version(__u16 version) { case TLS_VERSION13: return true; } - return false; } @@ -82,56 +81,85 @@ static __always_inline bool is_valid_tls_app_data(tls_record_header_t *hdr, __u3 return hdr->length + sizeof(tls_record_header_t) <= skb_len; } -static __always_inline bool parse_supported_versions_extension(struct __sk_buff *skb, __u64 offset, __u16 extension_length) { +// parse_supported_versions_extension parses the supported_versions extension, extracting the list of +// supported versions for the client or the selected version for the server. +static __always_inline bool parse_supported_versions_extension(struct __sk_buff *skb, __u64 offset, __u16 extension_length, bool is_client_hello) { __u64 skb_len = skb->len; - // Ensure we have at least 1 byte for the list length - if (offset + 1 > skb_len) { - return false; - } + if (is_client_hello) { + // ClientHello Supported Versions Extension + // Ensure we have at least 1 byte for the list length + if (offset + 1 > skb_len) { + return false; + } - // Read Supported Versions Length (1 byte) - __u8 sv_list_length; - if (bpf_skb_load_bytes(skb, offset, &sv_list_length, sizeof(sv_list_length)) < 0) { - return false; - } - offset += 1; + // Read Supported Versions Length (1 byte) + __u8 sv_list_length; + if (bpf_skb_load_bytes(skb, offset, &sv_list_length, sizeof(sv_list_length)) < 0) { + return false; + } + offset += 1; - // Ensure the list length is consistent with the extension length - if (sv_list_length + 1 > extension_length) { - return false; - } + // Ensure the list length is consistent with the extension length + if (sv_list_length + 1 > extension_length) { + return false; + } - // Ensure we don't read beyond the packet - if (offset + sv_list_length > skb_len) { - return false; - } + // Ensure we don't read beyond the packet + if (offset + sv_list_length > skb_len) { + return false; + } - // Set an upper bound for the loop to satisfy the eBPF verifier - #define MAX_SUPPORTED_VERSIONS 8 - __u8 versions_parsed = 0; + // Set an upper bound for the loop to satisfy the eBPF verifier + #define MAX_SUPPORTED_VERSIONS 8 + __u8 versions_parsed = 0; + + // Read the list of supported versions (2 bytes each) + for (__u8 i = 0; i + 1 < sv_list_length && versions_parsed < MAX_SUPPORTED_VERSIONS; i += 2, versions_parsed++) { + __u16 sv_version; + if (bpf_skb_load_bytes(skb, offset, &sv_version, sizeof(sv_version)) < 0) { + return false; + } + sv_version = bpf_ntohs(sv_version); + offset += 2; + + if (sv_version == TLS_VERSION13) { + log_debug("adamk supported version (ClientHello): TLS 1.3"); + } - // Read the list of supported versions (2 bytes each) - for (__u8 i = 0; i + 1 < sv_list_length && versions_parsed < MAX_SUPPORTED_VERSIONS; i += 2, versions_parsed++) { - __u16 sv_version; - if (bpf_skb_load_bytes(skb, offset, &sv_version, sizeof(sv_version)) < 0) { + // TODO: Store the supported version as needed + } + log_debug("adamk supported versions parsed (ClientHello): %d", versions_parsed); + return true; + } else { + // ServerHello Supported Versions Extension + // The extension length should be exactly 2 bytes + if (extension_length != 2) { return false; } - sv_version = bpf_ntohs(sv_version); - offset += 2; - if (sv_version == TLS_VERSION13) { - log_debug("adamk supported version: TLS 1.3"); + if (offset + 2 > skb_len) { + return false; } - // TODO: add supported version to the map + // Read Selected Version (2 bytes) + __u16 selected_version; + if (bpf_skb_load_bytes(skb, offset, &selected_version, sizeof(selected_version)) < 0) { + return false; + } + selected_version = bpf_ntohs(selected_version); + if (selected_version == TLS_VERSION13) { + log_debug("adamk selected version (ServerHello): TLS 1.3"); + } + + // TODO: Store the selected version as needed + return true; } - log_debug("adamk supported versions parsed: %d", versions_parsed); - return true; } -static __always_inline bool parse_client_hello_extensions(struct __sk_buff *skb, __u64 offset, __u16 extensions_length) { +// parse_tls_extensions parses the TLS extensions in the ClientHello or ServerHello message. +static __always_inline bool parse_tls_extensions(struct __sk_buff *skb, __u64 offset, __u16 extensions_length, bool is_client_hello) { __u64 skb_len = skb->len; __u64 extensions_end = offset + extensions_length; @@ -163,7 +191,7 @@ static __always_inline bool parse_client_hello_extensions(struct __sk_buff *skb, // Check for supported_versions extension (type 0x002B) if (extension_type == 0x002B) { - if (!parse_supported_versions_extension(skb, offset, extension_length)) { + if (!parse_supported_versions_extension(skb, offset, extension_length, is_client_hello)) { return false; } } @@ -176,6 +204,7 @@ static __always_inline bool parse_client_hello_extensions(struct __sk_buff *skb, return true; } +// parse_client_hello parses the ClientHello TLS payload. static __always_inline bool parse_client_hello(tls_record_header_t *hdr, struct __sk_buff *skb, __u64 offset) { __u32 skb_len = skb->len; @@ -265,8 +294,8 @@ static __always_inline bool parse_client_hello(tls_record_header_t *hdr, struct return false; } - // Parse Extensions - if (!parse_client_hello_extensions(skb, offset, extensions_length)) { + // Parse Extensions (is_client_hello = true) + if (!parse_tls_extensions(skb, offset, extensions_length, true /* is_client_hello */)) { return false; } @@ -274,6 +303,7 @@ static __always_inline bool parse_client_hello(tls_record_header_t *hdr, struct return true; } +// parse_server_hello parses the ServerHello TLS payload. static __always_inline bool parse_server_hello(tls_record_header_t *hdr, struct __sk_buff *skb, __u64 offset) { __u32 skb_len = skb->len; @@ -293,6 +323,8 @@ static __always_inline bool parse_server_hello(tls_record_header_t *hdr, struct if (offset + handshake_length > skb_len) return false; + __u64 handshake_end = offset + handshake_length; + // Read server version (2 bytes) __u16 server_version; if (bpf_skb_load_bytes(skb, offset, &server_version, sizeof(server_version)) < 0) @@ -301,9 +333,8 @@ static __always_inline bool parse_server_hello(tls_record_header_t *hdr, struct log_debug("adamk server version: %d", server_version); offset += 2; - // Validate server version - if (!is_valid_tls_version(server_version)) - return false; + // Note: In TLS 1.3, the server_version field is set to 0x0303 (TLS 1.2) + // The actual version is indicated in the supported_versions extension // Skip Random (32 bytes) offset += 32; @@ -318,7 +349,7 @@ static __always_inline bool parse_server_hello(tls_record_header_t *hdr, struct offset += session_id_len; // Ensure we don't read beyond the packet - if (offset + 2 > skb_len) + if (offset + 3 > skb_len) return false; // Read Cipher Suite (2 bytes) @@ -329,18 +360,17 @@ static __always_inline bool parse_server_hello(tls_record_header_t *hdr, struct log_debug("adamk server cipher_suite: %d", cipher_suite); offset += 2; - // You can store or process the cipher suite as needed - // Read Compression Method (1 byte) - if (offset + 1 > skb_len) - return false; __u8 compression_method; if (bpf_skb_load_bytes(skb, offset, &compression_method, sizeof(compression_method)) < 0) return false; offset += 1; - // Read Extensions Length (2 bytes) if present - if (offset + 2 <= skb_len) { + // Check if there are extensions + if (offset < handshake_end) { + // Read Extensions Length (2 bytes) + if (offset + 2 > skb_len || offset + 2 > handshake_end) + return false; __u16 extensions_length; if (bpf_skb_load_bytes(skb, offset, &extensions_length, sizeof(extensions_length)) < 0) return false; @@ -348,16 +378,20 @@ static __always_inline bool parse_server_hello(tls_record_header_t *hdr, struct offset += 2; // Ensure we don't read beyond the packet - if (offset + extensions_length > skb_len) + if (offset + extensions_length > skb_len || offset + extensions_length > handshake_end) return false; - // Process extensions if needed + // Parse Extensions (is_client_hello = false) + if (!parse_tls_extensions(skb, offset, extensions_length, false /* is_client_hello */)) { + return false; + } } // At this point, we've successfully parsed the ServerHello message return true; } +// is_tls_handshake checks if the given TLS record is a TLS handshake message. static __always_inline bool is_tls_handshake(tls_record_header_t *hdr, struct __sk_buff *skb, __u64 offset) { // Read handshake type __u8 handshake_type; @@ -366,16 +400,17 @@ static __always_inline bool is_tls_handshake(tls_record_header_t *hdr, struct __ // Only proceed if it's a ClientHello or ServerHello if (handshake_type == TLS_HANDSHAKE_CLIENT_HELLO) { - log_debug("adamk inspecting client hello"); + log_debug("adamk inspecting ClientHello"); return parse_client_hello(hdr, skb, offset); } else if (handshake_type == TLS_HANDSHAKE_SERVER_HELLO) { - log_debug("adamk inspecting server hello"); + log_debug("adamk inspecting ServerHello"); return parse_server_hello(hdr, skb, offset); } else { return false; } } +// is_tls checks if the given packet is a TLS packet. static __always_inline bool is_tls(struct __sk_buff *skb, __u64 nh_off) { __u32 skb_len = skb->len; @@ -392,12 +427,6 @@ static __always_inline bool is_tls(struct __sk_buff *skb, __u64 nh_off) { tls_hdr.version = bpf_ntohs(tls_hdr.version); tls_hdr.length = bpf_ntohs(tls_hdr.length); - // Validate version and length - if (!is_valid_tls_version(tls_hdr.version)) - return false; - if (tls_hdr.length > TLS_MAX_PAYLOAD_LENGTH) - return false; - // Validate version and length if (!is_valid_tls_version(tls_hdr.version)) return false; @@ -414,7 +443,6 @@ static __always_inline bool is_tls(struct __sk_buff *skb, __u64 nh_off) { // Handle based on content type switch (tls_hdr.content_type) { case TLS_HANDSHAKE: { - // return is_tls_handshake(&tls_hdr, skb, nh_off); bool handshake = is_tls_handshake(&tls_hdr, skb, nh_off); log_debug("adamk is_tls_handshake: %d", handshake); return handshake; diff --git a/pkg/network/tracer/testutil/tcp.go b/pkg/network/tracer/testutil/tcp.go index 9ef69afc27172c..0c4b749368aa1a 100644 --- a/pkg/network/tracer/testutil/tcp.go +++ b/pkg/network/tracer/testutil/tcp.go @@ -75,3 +75,16 @@ func (t *TCPServer) Shutdown() { t.ln = nil } } + +// GetFreePort returns a free port on localhost +func GetFreePort() (port uint16, err error) { + var a *net.TCPAddr + if a, err = net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { + var l *net.TCPListener + if l, err = net.ListenTCP("tcp", a); err == nil { + defer l.Close() + return uint16(l.Addr().(*net.TCPAddr).Port), nil + } + } + return +} diff --git a/pkg/network/tracer/tracer_test.go b/pkg/network/tracer/tracer_test.go index cf83b1ac75fa21..36dc3afa805fa8 100644 --- a/pkg/network/tracer/tracer_test.go +++ b/pkg/network/tracer/tracer_test.go @@ -12,7 +12,6 @@ import ( "bytes" "context" "crypto/tls" - "crypto/x509" "encoding/json" "errors" "fmt" @@ -39,6 +38,9 @@ import ( "github.com/DataDog/datadog-agent/pkg/ebpf/ebpftest" "github.com/DataDog/datadog-agent/pkg/network" "github.com/DataDog/datadog-agent/pkg/network/config" + "github.com/DataDog/datadog-agent/pkg/network/protocols" + usmtestutil "github.com/DataDog/datadog-agent/pkg/network/protocols/http/testutil" + "github.com/DataDog/datadog-agent/pkg/network/tracer/connection/kprobe" "github.com/DataDog/datadog-agent/pkg/network/tracer/testutil" "github.com/DataDog/datadog-agent/pkg/network/tracer/testutil/testdns" "github.com/DataDog/datadog-agent/pkg/process/util" @@ -1055,7 +1057,6 @@ func (s *TracerSuite) TestDNSStats() { func (s *TracerSuite) TestTCPEstablished() { t := s.T() - // Ensure closed connections are flushed as soon as possible cfg := testConfig() tr := setupTracer(t, cfg) @@ -1092,7 +1093,6 @@ func (s *TracerSuite) TestTCPEstablished() { require.True(t, ok) assert.Equal(t, uint32(0), conn.Last.TCPEstablished) assert.Equal(t, uint32(1), conn.Last.TCPClosed) - assert.Equal(t, uint32(1), conn.Tags[0]) } func (s *TracerSuite) TestTCPEstablishedPreExistingConn() { @@ -1411,246 +1411,87 @@ func findFailedConnectionByRemoteAddr(remoteAddr string, conns *network.Connecti return network.FirstConnection(conns, failureFilter) } -var serverCertPEM = []byte(`-----BEGIN CERTIFICATE----- -... (your server certificate here) ... ------END CERTIFICATE-----`) - -var serverKeyPEM = []byte(`-----BEGIN RSA PRIVATE KEY----- -... (your server private key here) ... ------END RSA PRIVATE KEY-----`) - -var clientCertPEM = []byte(`-----BEGIN CERTIFICATE----- -... (your client certificate here) ... ------END CERTIFICATE-----`) - -var clientKeyPEM = []byte(`-----BEGIN RSA PRIVATE KEY----- -... (your client private key here) ... ------END RSA PRIVATE KEY-----`) - -var caCertPEM = []byte(`-----BEGIN CERTIFICATE----- -... (your CA certificate here) ... ------END CERTIFICATE-----`) - -func (s *TracerSuite) TestTLSConnection() { +func (s *TracerSuite) TestTLSClassification() { t := s.T() - t.Skip() - - // Setup tracer with default configuration cfg := testConfig() - tr := setupTracer(t, cfg) - - // Start a TLS server - cert, err := tls.X509KeyPair(serverCertPEM, serverKeyPEM) - require.NoError(t, err, "failed to load server key pair") - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{cert}, + if !kprobe.ClassificationSupported(cfg) { + t.Skip("TLS classification platform not supported") } - ln, err := tls.Listen("tcp", "localhost:0", tlsConfig) - require.NoError(t, err, "failed to start TLS server") - defer ln.Close() - - serverAddr := ln.Addr().String() - - // Start server goroutine - done := make(chan struct{}) - go func() { - defer close(done) - conn, err := ln.Accept() - if err != nil { - t.Log("Server accept error:", err) - return - } - defer conn.Close() - io.Copy(io.Discard, conn) - }() - - // Create a TLS client connection - clientTLSConfig := &tls.Config{ - InsecureSkipVerify: true, // For testing purposes only - } - conn, err := tls.Dial("tcp", serverAddr, clientTLSConfig) - require.NoError(t, err, "failed to connect to TLS server") - defer conn.Close() - - // Send some data - _, err = conn.Write([]byte("hello")) - require.NoError(t, err, "failed to write to TLS server") - - // Wait for the tracer to pick up the connection - var tracedConn *network.ConnectionStats - require.Eventually(t, func() bool { - conns := getConnections(t, tr) - laddr, raddr := conn.LocalAddr(), conn.RemoteAddr() - var ok bool - tracedConn, ok = findConnection(laddr, raddr, conns) - return ok - }, 3*time.Second, 100*time.Millisecond, "could not find TLS connection") - - require.NotNil(t, tracedConn, "traced connection is nil") - - // Check that the TLS tags are present - tags := tracedConn.Tags - assert.Contains(t, tags, "tls_version", "tls_version tag missing") - assert.Contains(t, tags, "tls_cipher_suite", "tls_cipher_suite tag missing") - - // Signal the server to finish - conn.Close() - <-done -} - -func (s *TracerSuite) TestTLSConnectionWithClientCert() { - t := s.T() - t.Skip() + port, err := testutil.GetFreePort() + require.NoError(t, err) + portAsString := strconv.Itoa(int(port)) - // Setup tracer with default configuration - cfg := testConfig() tr := setupTracer(t, cfg) - // Load server certificate and key - serverCert, err := tls.X509KeyPair(serverCertPEM, serverKeyPEM) - require.NoError(t, err, "failed to load server key pair") - - // Load client certificate and key - clientCert, err := tls.X509KeyPair(clientCertPEM, clientKeyPEM) - require.NoError(t, err, "failed to load client key pair") - - // Create a certificate pool with the server's CA - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCertPEM) - - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{serverCert}, - ClientCAs: caCertPool, - ClientAuth: tls.RequireAndVerifyClientCert, - } - - ln, err := tls.Listen("tcp", "localhost:0", tlsConfig) - require.NoError(t, err, "failed to start TLS server") - defer ln.Close() - - serverAddr := ln.Addr().String() - - // Start server goroutine - done := make(chan struct{}) - go func() { - defer close(done) - conn, err := ln.Accept() - if err != nil { - t.Log("Server accept error:", err) - return - } - defer conn.Close() - io.Copy(io.Discard, conn) - }() - - // Create a TLS client connection - clientTLSConfig := &tls.Config{ - Certificates: []tls.Certificate{clientCert}, - InsecureSkipVerify: true, // For testing purposes only + type tlsTest struct { + name string + postTracerSetup func(t *testing.T) + validation func(t *testing.T, tr *Tracer) } - conn, err := tls.Dial("tcp", serverAddr, clientTLSConfig) - require.NoError(t, err, "failed to connect to TLS server") - defer conn.Close() - - // Send some data - _, err = conn.Write([]byte("hello")) - require.NoError(t, err, "failed to write to TLS server") - - // Wait for the tracer to pick up the connection - var tracedConn *network.ConnectionStats - require.Eventually(t, func() bool { - conns := getConnections(t, tr) - laddr, raddr := conn.LocalAddr(), conn.RemoteAddr() - var ok bool - tracedConn, ok = findConnection(laddr, raddr, conns) - return ok - }, 3*time.Second, 100*time.Millisecond, "could not find TLS connection") - - require.NotNil(t, tracedConn, "traced connection is nil") - - // Check that the TLS tags are present - tags := tracedConn.Tags - assert.Contains(t, tags, "tls_version", "tls_version tag missing") - assert.Contains(t, tags, "tls_cipher_suite", "tls_cipher_suite tag missing") - - // Signal the server to finish - conn.Close() - <-done -} - -func (s *TracerSuite) TestTLSConnectionTLS12() { - t := s.T() - t.Skip() - - // Setup tracer with default configuration - cfg := testConfig() - tr := setupTracer(t, cfg) - - // Start a TLS server with TLS 1.2 - cert, err := tls.X509KeyPair(serverCertPEM, serverKeyPEM) - require.NoError(t, err, "failed to load server key pair") - - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{cert}, - MinVersion: tls.VersionTLS12, - MaxVersion: tls.VersionTLS12, + tests := make([]tlsTest, 0) + for _, scenario := range []uint16{tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13} { + scenario := scenario + tests = append(tests, tlsTest{ + name: strings.Replace(tls.VersionName(scenario), " ", "-", 1), + postTracerSetup: func(t *testing.T) { + srv := usmtestutil.NewTLSServerWithSpecificVersion("localhost:"+portAsString, func(conn net.Conn) { + defer conn.Close() + // Echo back whatever is received + _, err := io.Copy(conn, conn) + if err != nil { + fmt.Printf("Failed to echo data: %v\n", err) + return + } + }, scenario) + done := make(chan struct{}) + require.NoError(t, srv.Run(done)) + t.Cleanup(func() { close(done) }) + tlsConfig := &tls.Config{ + MinVersion: scenario, + MaxVersion: scenario, + InsecureSkipVerify: true, + } + conn, err := net.Dial("tcp", "localhost:"+portAsString) + require.NoError(t, err) + defer conn.Close() + + // Wrap the TCP connection with TLS + tlsConn := tls.Client(conn, tlsConfig) + + // Perform the TLS handshake + require.NoError(t, tlsConn.Handshake()) + }, + validation: func(t *testing.T, tr *Tracer) { + // Iterate through active connections until we find connection created above + require.Eventuallyf(t, func() bool { + payload := getConnections(t, tr) + for _, c := range payload.Conns { + if c.DPort == port && c.ProtocolStack.Contains(protocols.TLS) { + return true + } + } + return false + }, 3*time.Second, 100*time.Millisecond, "couldn't find TLS connection matching: dst port %v", portAsString) + }, + }) } - ln, err := tls.Listen("tcp", "localhost:0", tlsConfig) - require.NoError(t, err, "failed to start TLS server") - defer ln.Close() - - serverAddr := ln.Addr().String() - - // Start server goroutine - done := make(chan struct{}) - go func() { - defer close(done) - conn, err := ln.Accept() - if err != nil { - t.Log("Server accept error:", err) - return - } - defer conn.Close() - io.Copy(io.Discard, conn) - }() - - // Create a TLS client connection with TLS 1.2 - clientTLSConfig := &tls.Config{ - InsecureSkipVerify: true, // For testing purposes only - MinVersion: tls.VersionTLS12, - MaxVersion: tls.VersionTLS12, + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if ebpftest.GetBuildMode() == ebpftest.Fentry { + t.Skip("protocol classification not supported for fentry tracer") + } + t.Cleanup(func() { tr.RemoveClient(clientID) }) + t.Cleanup(func() { _ = tr.Pause() }) + + tr.RemoveClient(clientID) + require.NoError(t, tr.RegisterClient(clientID)) + require.NoError(t, tr.Resume(), "enable probes - before post tracer") + tt.postTracerSetup(t) + require.NoError(t, tr.Pause(), "disable probes - after post tracer") + tt.validation(t, tr) + }) } - conn, err := tls.Dial("tcp", serverAddr, clientTLSConfig) - require.NoError(t, err, "failed to connect to TLS server") - defer conn.Close() - - // Send some data - _, err = conn.Write([]byte("hello")) - require.NoError(t, err, "failed to write to TLS server") - - // Wait for the tracer to pick up the connection - var tracedConn *network.ConnectionStats - require.Eventually(t, func() bool { - conns := getConnections(t, tr) - laddr, raddr := conn.LocalAddr(), conn.RemoteAddr() - var ok bool - tracedConn, ok = findConnection(laddr, raddr, conns) - return ok - }, 3*time.Second, 100*time.Millisecond, "could not find TLS connection") - - require.NotNil(t, tracedConn, "traced connection is nil") - - // Check that the TLS tags are present and correct - tags := tracedConn.Tags - assert.Contains(t, tags, "tls_version", "tls_version tag missing") - //assert.Equal(t, "TLSv1.2", tags["tls_version"], "expected TLS version 1.2") // TDOD: Fix this - - assert.Contains(t, tags, "tls_cipher_suite", "tls_cipher_suite tag missing") - - // Signal the server to finish - conn.Close() - <-done } diff --git a/pkg/network/usm/tests/tracer_usm_linux_test.go b/pkg/network/usm/tests/tracer_usm_linux_test.go index 825e4203991939..a22ffe89f3e28b 100644 --- a/pkg/network/usm/tests/tracer_usm_linux_test.go +++ b/pkg/network/usm/tests/tracer_usm_linux_test.go @@ -282,18 +282,6 @@ func testProtocolConnectionProtocolMapCleanup(t *testing.T, tr *tracer.Tracer, c }) } -func getFreePort() (port uint16, err error) { - var a *net.TCPAddr - if a, err = net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { - var l *net.TCPListener - if l, err = net.ListenTCP("tcp", a); err == nil { - defer l.Close() - return uint16(l.Addr().(*net.TCPAddr).Port), nil - } - } - return -} - func (s *USMSuite) TestIgnoreTLSClassificationIfApplicationProtocolWasDetected() { t := s.T() cfg := tracertestutil.Config() @@ -391,7 +379,7 @@ func (s *USMSuite) TestIgnoreTLSClassificationIfApplicationProtocolWasDetected() } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - clientPort, err := getFreePort() + clientPort, err := tracertestutil.GetFreePort() require.NoError(t, err) dialer := &net.Dialer{ LocalAddr: &net.TCPAddr{ @@ -451,7 +439,7 @@ func (s *USMSuite) TestTLSClassification() { t.Skip("TLS classification platform not supported") } - port, err := getFreePort() + port, err := tracertestutil.GetFreePort() require.NoError(t, err) portAsString := strconv.Itoa(int(port)) From 69ff889d6a79890d65455764bc7ff1f917eab36f Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Mon, 21 Oct 2024 11:21:10 -0400 Subject: [PATCH 05/53] store tls tags in map --- .../classification/protocol-classification.h | 2 +- .../classification/shared-tracer-maps.h | 16 ++++++ pkg/network/ebpf/c/protocols/tls/tls.h | 54 +++++++++++++------ 3 files changed, 54 insertions(+), 18 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h index a2641ff91cc127..7f6cd1206de158 100644 --- a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h +++ b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h @@ -169,7 +169,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct protocol_t app_layer_proto = get_protocol_from_stack(protocol_stack, LAYER_APPLICATION); - if ((app_layer_proto == PROTOCOL_UNKNOWN || app_layer_proto == PROTOCOL_POSTGRES) && is_tls(skb, skb_info.data_off)) { + if ((app_layer_proto == PROTOCOL_UNKNOWN || app_layer_proto == PROTOCOL_POSTGRES) && is_tls(skb, skb_info.data_off, &skb_tup)) { // TLS classification update_protocol_information(usm_ctx, protocol_stack, PROTOCOL_TLS); // The connection is TLS encrypted, thus we cannot classify the protocol diff --git a/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h b/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h index 2f6da79cbbcae1..cb9d94a5e25c8a 100644 --- a/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h +++ b/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h @@ -24,6 +24,22 @@ static __always_inline tls_enhanced_tags_t* get_tls_enhanced_tags(conn_tuple_t* return bpf_map_lookup_elem(&tls_enhanced_tags, tuple); } +static __always_inline tls_enhanced_tags_t* get_or_create_tls_enhanced_tags(conn_tuple_t *tuple) { + conn_tuple_t normalized_tup = *tuple; + normalize_tuple(&normalized_tup); + + tls_enhanced_tags_t *tags = bpf_map_lookup_elem(&tls_enhanced_tags, &normalized_tup); + if (!tags) { + // Initialize a new entry + tls_enhanced_tags_t empty_tags = {0}; + bpf_map_update_elem(&tls_enhanced_tags, &normalized_tup, &empty_tags, BPF_NOEXIST); + + // Lookup again + tags = bpf_map_lookup_elem(&tls_enhanced_tags, &normalized_tup); + } + return tags; +} + static __always_inline protocol_stack_t* get_protocol_stack(conn_tuple_t *skb_tup) { conn_tuple_t normalized_tup = *skb_tup; normalize_tuple(&normalized_tup); diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index 4d2622c91c9bea..690c391d9ae271 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -3,6 +3,7 @@ #include "ktypes.h" #include "bpf_builtins.h" +#include "shared-tracer-maps.h" #include "ip.h" @@ -18,6 +19,8 @@ #define TLS_HANDSHAKE 0x16 #define TLS_APPLICATION_DATA 0x17 +#define TLS_EXTENSION_SUPPORTED_VERSIONS 0x002b + /* https://www.rfc-editor.org/rfc/rfc5246#page-19 6.2. Record Layer */ #define TLS_MAX_PAYLOAD_LENGTH (1 << 14) @@ -83,7 +86,7 @@ static __always_inline bool is_valid_tls_app_data(tls_record_header_t *hdr, __u3 // parse_supported_versions_extension parses the supported_versions extension, extracting the list of // supported versions for the client or the selected version for the server. -static __always_inline bool parse_supported_versions_extension(struct __sk_buff *skb, __u64 offset, __u16 extension_length, bool is_client_hello) { +static __always_inline bool parse_supported_versions_extension(struct __sk_buff *skb, __u64 offset, __u16 extension_length, bool is_client_hello, conn_tuple_t *tup) { __u64 skb_len = skb->len; if (is_client_hello) { @@ -110,6 +113,11 @@ static __always_inline bool parse_supported_versions_extension(struct __sk_buff return false; } + tls_enhanced_tags_t *tags = get_or_create_tls_enhanced_tags(tup); + if (!tags) { + return false; + } + // Set an upper bound for the loop to satisfy the eBPF verifier #define MAX_SUPPORTED_VERSIONS 8 __u8 versions_parsed = 0; @@ -126,8 +134,7 @@ static __always_inline bool parse_supported_versions_extension(struct __sk_buff if (sv_version == TLS_VERSION13) { log_debug("adamk supported version (ClientHello): TLS 1.3"); } - - // TODO: Store the supported version as needed + tags->client_tags.offered_versions[versions_parsed] = sv_version; } log_debug("adamk supported versions parsed (ClientHello): %d", versions_parsed); return true; @@ -149,17 +156,23 @@ static __always_inline bool parse_supported_versions_extension(struct __sk_buff } selected_version = bpf_ntohs(selected_version); + tls_enhanced_tags_t *tags = get_or_create_tls_enhanced_tags(tup); + if (!tags) { + return false; + } + + tags->server_tags.version = selected_version; + if (selected_version == TLS_VERSION13) { log_debug("adamk selected version (ServerHello): TLS 1.3"); } - // TODO: Store the selected version as needed return true; } } // parse_tls_extensions parses the TLS extensions in the ClientHello or ServerHello message. -static __always_inline bool parse_tls_extensions(struct __sk_buff *skb, __u64 offset, __u16 extensions_length, bool is_client_hello) { +static __always_inline bool parse_tls_extensions(struct __sk_buff *skb, __u64 offset, __u16 extensions_length, bool is_client_hello, conn_tuple_t *tup) { __u64 skb_len = skb->len; __u64 extensions_end = offset + extensions_length; @@ -189,9 +202,9 @@ static __always_inline bool parse_tls_extensions(struct __sk_buff *skb, __u64 of return false; } - // Check for supported_versions extension (type 0x002B) - if (extension_type == 0x002B) { - if (!parse_supported_versions_extension(skb, offset, extension_length, is_client_hello)) { + // Check for supported_versions extension (0x002B) + if (extension_type == TLS_EXTENSION_SUPPORTED_VERSIONS) { + if (!parse_supported_versions_extension(skb, offset, extension_length, is_client_hello, tup)) { return false; } } @@ -205,7 +218,7 @@ static __always_inline bool parse_tls_extensions(struct __sk_buff *skb, __u64 of } // parse_client_hello parses the ClientHello TLS payload. -static __always_inline bool parse_client_hello(tls_record_header_t *hdr, struct __sk_buff *skb, __u64 offset) { +static __always_inline bool parse_client_hello(tls_record_header_t *hdr, struct __sk_buff *skb, __u64 offset, conn_tuple_t *tup) { __u32 skb_len = skb->len; // Move offset past handshake type (1 byte) @@ -295,7 +308,7 @@ static __always_inline bool parse_client_hello(tls_record_header_t *hdr, struct } // Parse Extensions (is_client_hello = true) - if (!parse_tls_extensions(skb, offset, extensions_length, true /* is_client_hello */)) { + if (!parse_tls_extensions(skb, offset, extensions_length, true, tup)) { return false; } @@ -304,7 +317,7 @@ static __always_inline bool parse_client_hello(tls_record_header_t *hdr, struct } // parse_server_hello parses the ServerHello TLS payload. -static __always_inline bool parse_server_hello(tls_record_header_t *hdr, struct __sk_buff *skb, __u64 offset) { +static __always_inline bool parse_server_hello(tls_record_header_t *hdr, struct __sk_buff *skb, __u64 offset, conn_tuple_t *tup) { __u32 skb_len = skb->len; // Move offset past handshake type (1 byte) @@ -360,11 +373,18 @@ static __always_inline bool parse_server_hello(tls_record_header_t *hdr, struct log_debug("adamk server cipher_suite: %d", cipher_suite); offset += 2; + tls_enhanced_tags_t *tags = get_or_create_tls_enhanced_tags(tup); + if (!tags) { + return false; + } + tags->server_tags.cipher_suite = cipher_suite; + // Read Compression Method (1 byte) __u8 compression_method; if (bpf_skb_load_bytes(skb, offset, &compression_method, sizeof(compression_method)) < 0) return false; offset += 1; + tags->server_tags.compression_method = compression_method; // Check if there are extensions if (offset < handshake_end) { @@ -382,7 +402,7 @@ static __always_inline bool parse_server_hello(tls_record_header_t *hdr, struct return false; // Parse Extensions (is_client_hello = false) - if (!parse_tls_extensions(skb, offset, extensions_length, false /* is_client_hello */)) { + if (!parse_tls_extensions(skb, offset, extensions_length, false, tup)) { return false; } } @@ -392,7 +412,7 @@ static __always_inline bool parse_server_hello(tls_record_header_t *hdr, struct } // is_tls_handshake checks if the given TLS record is a TLS handshake message. -static __always_inline bool is_tls_handshake(tls_record_header_t *hdr, struct __sk_buff *skb, __u64 offset) { +static __always_inline bool is_tls_handshake(tls_record_header_t *hdr, struct __sk_buff *skb, __u64 offset, conn_tuple_t *tup) { // Read handshake type __u8 handshake_type; if (bpf_skb_load_bytes(skb, offset, &handshake_type, sizeof(handshake_type)) < 0) @@ -401,17 +421,17 @@ static __always_inline bool is_tls_handshake(tls_record_header_t *hdr, struct __ // Only proceed if it's a ClientHello or ServerHello if (handshake_type == TLS_HANDSHAKE_CLIENT_HELLO) { log_debug("adamk inspecting ClientHello"); - return parse_client_hello(hdr, skb, offset); + return parse_client_hello(hdr, skb, offset, tup); } else if (handshake_type == TLS_HANDSHAKE_SERVER_HELLO) { log_debug("adamk inspecting ServerHello"); - return parse_server_hello(hdr, skb, offset); + return parse_server_hello(hdr, skb, offset, tup); } else { return false; } } // is_tls checks if the given packet is a TLS packet. -static __always_inline bool is_tls(struct __sk_buff *skb, __u64 nh_off) { +static __always_inline bool is_tls(struct __sk_buff *skb, __u64 nh_off, conn_tuple_t *tup) { __u32 skb_len = skb->len; // Ensure there's enough space for TLS record header @@ -443,7 +463,7 @@ static __always_inline bool is_tls(struct __sk_buff *skb, __u64 nh_off) { // Handle based on content type switch (tls_hdr.content_type) { case TLS_HANDSHAKE: { - bool handshake = is_tls_handshake(&tls_hdr, skb, nh_off); + bool handshake = is_tls_handshake(&tls_hdr, skb, nh_off, tup); log_debug("adamk is_tls_handshake: %d", handshake); return handshake; } From 3648687f7cb7f89a44b471a7054366dcdc7fd321 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Tue, 22 Oct 2024 12:35:57 -0400 Subject: [PATCH 06/53] fix processing of server hello packets --- .../classification/protocol-classification.h | 27 +- pkg/network/ebpf/c/protocols/tls/tls.h | 515 ++++++++---------- pkg/network/ebpf/c/tracer/events.h | 1 + pkg/network/tracer/tracer_test.go | 17 +- 4 files changed, 248 insertions(+), 312 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h index 7f6cd1206de158..76cef90250503e 100644 --- a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h +++ b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h @@ -158,7 +158,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct return; } - if (is_fully_classified(protocol_stack) || is_protocol_layer_known(protocol_stack, LAYER_ENCRYPTION)) { + if (is_fully_classified(protocol_stack) ) { return; } @@ -169,11 +169,30 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct protocol_t app_layer_proto = get_protocol_from_stack(protocol_stack, LAYER_APPLICATION); - if ((app_layer_proto == PROTOCOL_UNKNOWN || app_layer_proto == PROTOCOL_POSTGRES) && is_tls(skb, skb_info.data_off, &skb_tup)) { + tls_record_header_t tls_hdr = {0}; + + if (is_tls(skb, skb_info.data_off, &tls_hdr)) { // TLS classification update_protocol_information(usm_ctx, protocol_stack, PROTOCOL_TLS); - // The connection is TLS encrypted, thus we cannot classify the protocol - // using the socket filter and therefore we can bail out; + + // Parse TLS payload + tls_enhanced_tags_t *tags = get_or_create_tls_enhanced_tags(&skb_tup); + if (tags) { + // Parse the TLS payload and update the tags + int ret = parse_tls_payload(skb, skb_info.data_off, &tls_hdr, tags); + log_debug("adamk\n"); + log_debug("adamk tls classification: parse_tls_payload=%d", ret); + ret++; + log_debug("adamk tls classification: client version 1=%d", tags->client_tags.offered_versions[0]); + log_debug("adamk tls classification: server version=%d", tags->server_tags.version); + log_debug("adamk tls classification: server cipher=%d", tags->server_tags.cipher_suite); + } + // The connection is TLS encrypted, thus we cannot further classify the protocol + // using the socket filter and can bail out; + return; + } + + if(is_protocol_layer_known(protocol_stack, LAYER_ENCRYPTION)) { return; } diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index 690c391d9ae271..41ff9657f70504 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -3,8 +3,6 @@ #include "ktypes.h" #include "bpf_builtins.h" -#include "shared-tracer-maps.h" - #include "ip.h" #define ETH_HLEN 14 // Ethernet header length @@ -19,8 +17,6 @@ #define TLS_HANDSHAKE 0x16 #define TLS_APPLICATION_DATA 0x17 -#define TLS_EXTENSION_SUPPORTED_VERSIONS 0x002b - /* https://www.rfc-editor.org/rfc/rfc5246#page-19 6.2. Record Layer */ #define TLS_MAX_PAYLOAD_LENGTH (1 << 14) @@ -32,20 +28,15 @@ typedef struct { __u16 length; } __attribute__((packed)) tls_record_header_t; -typedef struct { - __u8 handshake_type; - __u8 length[3]; - __u16 version; -} __attribute__((packed)) tls_hello_message_t; - +// TLS enhanced tags structures typedef struct { __u16 offered_versions[6]; + __u8 num_offered_versions; } tls_client_tags_t; typedef struct { __u16 version; __u16 cipher_suite; - __u8 compression_method; } tls_server_tags_t; typedef struct { @@ -55,423 +46,345 @@ typedef struct { #define TLS_HANDSHAKE_CLIENT_HELLO 0x01 #define TLS_HANDSHAKE_SERVER_HELLO 0x02 -// The size of the handshake type and the length. -#define TLS_HELLO_MESSAGE_HEADER_SIZE 4 -// is_valid_tls_version checks if the given version is a valid TLS version as -// defined in the TLS specification. +// Function to check if the given version is a valid TLS version static __always_inline bool is_valid_tls_version(__u16 version) { switch (version) { - case SSL_VERSION20: - case SSL_VERSION30: - case TLS_VERSION10: - case TLS_VERSION11: - case TLS_VERSION12: - case TLS_VERSION13: - return true; + case SSL_VERSION20: + case SSL_VERSION30: + case TLS_VERSION10: + case TLS_VERSION11: + case TLS_VERSION12: + case TLS_VERSION13: + return true; + default: + return false; } - return false; } -// is_valid_tls_app_data checks if the buffer is a valid TLS Application Data -// record header. The record header is considered valid if: -// - the TLS version field is a known SSL/TLS version -// - the payload length is below the maximum payload length defined in the -// standard. -// - the payload length + the size of the record header is less than the size -// of the skb -static __always_inline bool is_valid_tls_app_data(tls_record_header_t *hdr, __u32 skb_len) { - return hdr->length + sizeof(tls_record_header_t) <= skb_len; -} - -// parse_supported_versions_extension parses the supported_versions extension, extracting the list of -// supported versions for the client or the selected version for the server. -static __always_inline bool parse_supported_versions_extension(struct __sk_buff *skb, __u64 offset, __u16 extension_length, bool is_client_hello, conn_tuple_t *tup) { - __u64 skb_len = skb->len; - - if (is_client_hello) { - // ClientHello Supported Versions Extension - // Ensure we have at least 1 byte for the list length - if (offset + 1 > skb_len) { - return false; - } - - // Read Supported Versions Length (1 byte) - __u8 sv_list_length; - if (bpf_skb_load_bytes(skb, offset, &sv_list_length, sizeof(sv_list_length)) < 0) { - return false; - } - offset += 1; - - // Ensure the list length is consistent with the extension length - if (sv_list_length + 1 > extension_length) { - return false; - } - - // Ensure we don't read beyond the packet - if (offset + sv_list_length > skb_len) { - return false; - } - - tls_enhanced_tags_t *tags = get_or_create_tls_enhanced_tags(tup); - if (!tags) { - return false; - } - - // Set an upper bound for the loop to satisfy the eBPF verifier - #define MAX_SUPPORTED_VERSIONS 8 - __u8 versions_parsed = 0; - - // Read the list of supported versions (2 bytes each) - for (__u8 i = 0; i + 1 < sv_list_length && versions_parsed < MAX_SUPPORTED_VERSIONS; i += 2, versions_parsed++) { - __u16 sv_version; - if (bpf_skb_load_bytes(skb, offset, &sv_version, sizeof(sv_version)) < 0) { - return false; - } - sv_version = bpf_ntohs(sv_version); - offset += 2; - - if (sv_version == TLS_VERSION13) { - log_debug("adamk supported version (ClientHello): TLS 1.3"); - } - tags->client_tags.offered_versions[versions_parsed] = sv_version; - } - log_debug("adamk supported versions parsed (ClientHello): %d", versions_parsed); - return true; - } else { - // ServerHello Supported Versions Extension - // The extension length should be exactly 2 bytes - if (extension_length != 2) { - return false; - } +// Helper function to read and validate the TLS record header +static __always_inline bool read_tls_record_header(struct __sk_buff *skb, __u64 nh_off, tls_record_header_t *tls_hdr) { + __u32 skb_len = skb->len; - if (offset + 2 > skb_len) { - return false; - } + // Ensure there's enough space for TLS record header + if (nh_off + sizeof(tls_record_header_t) > skb_len) + return false; - // Read Selected Version (2 bytes) - __u16 selected_version; - if (bpf_skb_load_bytes(skb, offset, &selected_version, sizeof(selected_version)) < 0) { - return false; - } - selected_version = bpf_ntohs(selected_version); + // Read TLS record header + if (bpf_skb_load_bytes(skb, nh_off, tls_hdr, sizeof(tls_record_header_t)) < 0) + return false; - tls_enhanced_tags_t *tags = get_or_create_tls_enhanced_tags(tup); - if (!tags) { - return false; - } + // Convert fields to host byte order + tls_hdr->version = bpf_ntohs(tls_hdr->version); + tls_hdr->length = bpf_ntohs(tls_hdr->length); - tags->server_tags.version = selected_version; + // Validate version and length + if (!is_valid_tls_version(tls_hdr->version)) + return false; + if (tls_hdr->length > TLS_MAX_PAYLOAD_LENGTH) + return false; - if (selected_version == TLS_VERSION13) { - log_debug("adamk selected version (ServerHello): TLS 1.3"); - } + // Ensure we don't read beyond the packet + if (nh_off + sizeof(tls_record_header_t) + tls_hdr->length > skb_len) + return false; - return true; - } + return true; } -// parse_tls_extensions parses the TLS extensions in the ClientHello or ServerHello message. -static __always_inline bool parse_tls_extensions(struct __sk_buff *skb, __u64 offset, __u16 extensions_length, bool is_client_hello, conn_tuple_t *tup) { - __u64 skb_len = skb->len; - __u64 extensions_end = offset + extensions_length; - - // Set an upper bound for the loop to satisfy the eBPF verifier - #define MAX_EXTENSIONS 16 - __u8 extensions_parsed = 0; - - while (offset + 4 <= extensions_end && extensions_parsed < MAX_EXTENSIONS) { - // Read Extension Type (2 bytes) - __u16 extension_type; - if (bpf_skb_load_bytes(skb, offset, &extension_type, sizeof(extension_type)) < 0) { - return false; - } - extension_type = bpf_ntohs(extension_type); - offset += 2; - - // Read Extension Length (2 bytes) - __u16 extension_length; - if (bpf_skb_load_bytes(skb, offset, &extension_length, sizeof(extension_length)) < 0) { - return false; - } - extension_length = bpf_ntohs(extension_length); - offset += 2; - - // Ensure we don't read beyond the packet - if (offset + extension_length > skb_len || offset + extension_length > extensions_end) { - return false; - } - - // Check for supported_versions extension (0x002B) - if (extension_type == TLS_EXTENSION_SUPPORTED_VERSIONS) { - if (!parse_supported_versions_extension(skb, offset, extension_length, is_client_hello, tup)) { - return false; - } - } +// Function to check if the packet is TLS +static __always_inline bool is_tls(struct __sk_buff *skb, __u64 nh_off, tls_record_header_t *tls_hdr) { + // Use the helper function to read and validate the TLS record header + if (!read_tls_record_header(skb, nh_off, tls_hdr)) + return false; - // Skip to the next extension - offset += extension_length; - extensions_parsed++; - } + // Validate content type + if (tls_hdr->content_type != TLS_HANDSHAKE && tls_hdr->content_type != TLS_APPLICATION_DATA) + return false; return true; } -// parse_client_hello parses the ClientHello TLS payload. -static __always_inline bool parse_client_hello(tls_record_header_t *hdr, struct __sk_buff *skb, __u64 offset, conn_tuple_t *tup) { - __u32 skb_len = skb->len; - +// Function to parse ClientHello message +static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offset, __u32 skb_len, tls_enhanced_tags_t *tags) { // Move offset past handshake type (1 byte) offset += 1; // Read handshake length (3 bytes) __u8 handshake_length_bytes[3]; if (bpf_skb_load_bytes(skb, offset, handshake_length_bytes, 3) < 0) - return false; + return -1; __u32 handshake_length = (handshake_length_bytes[0] << 16) | (handshake_length_bytes[1] << 8) | - handshake_length_bytes[2]; + (handshake_length_bytes[2]); offset += 3; // Ensure we don't read beyond the packet if (offset + handshake_length > skb_len) - return false; + return -1; - // Read client version (2 bytes) + // Read client_version (2 bytes) __u16 client_version; if (bpf_skb_load_bytes(skb, offset, &client_version, sizeof(client_version)) < 0) - return false; + return -1; client_version = bpf_ntohs(client_version); - log_debug("adamk client version: %d", client_version); offset += 2; - // Validate client version - if (!is_valid_tls_version(client_version)) - return false; - // Skip Random (32 bytes) offset += 32; // Read Session ID Length (1 byte) - __u8 session_id_len; - if (bpf_skb_load_bytes(skb, offset, &session_id_len, sizeof(session_id_len)) < 0) - return false; + __u8 session_id_length; + if (bpf_skb_load_bytes(skb, offset, &session_id_length, sizeof(session_id_length)) < 0) + return -1; offset += 1; // Skip Session ID - offset += session_id_len; - - // Ensure we don't read beyond the packet - if (offset + 2 > skb_len) - return false; + offset += session_id_length; // Read Cipher Suites Length (2 bytes) __u16 cipher_suites_length; if (bpf_skb_load_bytes(skb, offset, &cipher_suites_length, sizeof(cipher_suites_length)) < 0) - return false; + return -1; cipher_suites_length = bpf_ntohs(cipher_suites_length); - log_debug("adamk client cipher_suites_length: %d", cipher_suites_length); offset += 2; - // Ensure we don't read beyond the packet - if (offset + cipher_suites_length > skb_len) - return false; - // Skip Cipher Suites offset += cipher_suites_length; // Read Compression Methods Length (1 byte) - if (offset + 1 > skb_len) - return false; __u8 compression_methods_length; if (bpf_skb_load_bytes(skb, offset, &compression_methods_length, sizeof(compression_methods_length)) < 0) - return false; + return -1; offset += 1; // Skip Compression Methods offset += compression_methods_length; + // Check if extensions are present + if (offset + 2 > skb_len) + return -1; + // Read Extensions Length (2 bytes) - if (offset + 2 > skb_len) { - return false; - } __u16 extensions_length; - if (bpf_skb_load_bytes(skb, offset, &extensions_length, sizeof(extensions_length)) < 0) { - return false; - } + if (bpf_skb_load_bytes(skb, offset, &extensions_length, sizeof(extensions_length)) < 0) + return -1; extensions_length = bpf_ntohs(extensions_length); offset += 2; // Ensure we don't read beyond the packet - if (offset + extensions_length > skb_len) { - return false; - } + if (offset + extensions_length > skb_len) + return -1; - // Parse Extensions (is_client_hello = true) - if (!parse_tls_extensions(skb, offset, extensions_length, true, tup)) { - return false; + // Parse Extensions + __u64 extensions_end = offset + extensions_length; + #define MAX_EXTENSIONS 16 + __u8 extensions_parsed = 0; + + while (offset + 4 <= extensions_end && extensions_parsed < MAX_EXTENSIONS) { + // Read Extension Type (2 bytes) + __u16 extension_type; + if (bpf_skb_load_bytes(skb, offset, &extension_type, sizeof(extension_type)) < 0) + return -1; + extension_type = bpf_ntohs(extension_type); + offset += 2; + + // Read Extension Length (2 bytes) + __u16 extension_length; + if (bpf_skb_load_bytes(skb, offset, &extension_length, sizeof(extension_length)) < 0) + return -1; + extension_length = bpf_ntohs(extension_length); + offset += 2; + + // Ensure we don't read beyond the packet + if (offset + extension_length > skb_len || offset + extension_length > extensions_end) + return -1; + + // Check for supported_versions extension (type 43 or 0x002B) + if (extension_type == 0x002B) { + // Parse supported_versions extension + if (offset + 1 > skb_len) + return -1; + + // Read list length (1 byte) + __u8 sv_list_length; + if (bpf_skb_load_bytes(skb, offset, &sv_list_length, sizeof(sv_list_length)) < 0) + return -1; + offset += 1; + + // Ensure we don't read beyond the packet + if (offset + sv_list_length > skb_len || offset + sv_list_length > extensions_end) + return -1; + + // Parse versions + __u8 num_versions = 0; + #define MAX_SUPPORTED_VERSIONS 6 + for (__u8 i = 0; i + 1 < sv_list_length && num_versions < MAX_SUPPORTED_VERSIONS; i += 2, num_versions++) { + __u16 sv_version; + if (bpf_skb_load_bytes(skb, offset, &sv_version, sizeof(sv_version)) < 0) + return -1; + sv_version = bpf_ntohs(sv_version); + offset += 2; + + tags->client_tags.offered_versions[num_versions] = sv_version; + } + tags->client_tags.num_offered_versions = num_versions; + } else { + // Skip other extensions + offset += extension_length; + } + + extensions_parsed++; } - // At this point, we've successfully parsed the ClientHello message - return true; + return 0; } -// parse_server_hello parses the ServerHello TLS payload. -static __always_inline bool parse_server_hello(tls_record_header_t *hdr, struct __sk_buff *skb, __u64 offset, conn_tuple_t *tup) { - __u32 skb_len = skb->len; - +// Function to parse ServerHello message +static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offset, __u32 skb_len, tls_enhanced_tags_t *tags) { // Move offset past handshake type (1 byte) offset += 1; // Read handshake length (3 bytes) __u8 handshake_length_bytes[3]; if (bpf_skb_load_bytes(skb, offset, handshake_length_bytes, 3) < 0) - return false; + return -1; __u32 handshake_length = (handshake_length_bytes[0] << 16) | (handshake_length_bytes[1] << 8) | - handshake_length_bytes[2]; + (handshake_length_bytes[2]); offset += 3; // Ensure we don't read beyond the packet if (offset + handshake_length > skb_len) - return false; + return -1; __u64 handshake_end = offset + handshake_length; - // Read server version (2 bytes) + // Read server_version (2 bytes) __u16 server_version; if (bpf_skb_load_bytes(skb, offset, &server_version, sizeof(server_version)) < 0) - return false; + return -1; server_version = bpf_ntohs(server_version); - log_debug("adamk server version: %d", server_version); - offset += 2; - + // Set the version here and try to get the "real" version from the extensions // Note: In TLS 1.3, the server_version field is set to 0x0303 (TLS 1.2) - // The actual version is indicated in the supported_versions extension + // The actual version is embedded in the supported_versions extension + tags->server_tags.version = server_version; + offset += 2; // Skip Random (32 bytes) offset += 32; // Read Session ID Length (1 byte) - __u8 session_id_len; - if (bpf_skb_load_bytes(skb, offset, &session_id_len, sizeof(session_id_len)) < 0) - return false; + __u8 session_id_length; + if (bpf_skb_load_bytes(skb, offset, &session_id_length, sizeof(session_id_length)) < 0) + return -1; offset += 1; // Skip Session ID - offset += session_id_len; - - // Ensure we don't read beyond the packet - if (offset + 3 > skb_len) - return false; + offset += session_id_length; // Read Cipher Suite (2 bytes) __u16 cipher_suite; if (bpf_skb_load_bytes(skb, offset, &cipher_suite, sizeof(cipher_suite)) < 0) - return false; + return -1; cipher_suite = bpf_ntohs(cipher_suite); - log_debug("adamk server cipher_suite: %d", cipher_suite); offset += 2; - tls_enhanced_tags_t *tags = get_or_create_tls_enhanced_tags(tup); - if (!tags) { - return false; - } - tags->server_tags.cipher_suite = cipher_suite; - - // Read Compression Method (1 byte) - __u8 compression_method; - if (bpf_skb_load_bytes(skb, offset, &compression_method, sizeof(compression_method)) < 0) - return false; + // Skip Compression Method (1 byte) offset += 1; - tags->server_tags.compression_method = compression_method; + + // Store parsed data into tags + tags->server_tags.cipher_suite = cipher_suite; // Check if there are extensions if (offset < handshake_end) { // Read Extensions Length (2 bytes) if (offset + 2 > skb_len || offset + 2 > handshake_end) - return false; + return -1; __u16 extensions_length; if (bpf_skb_load_bytes(skb, offset, &extensions_length, sizeof(extensions_length)) < 0) - return false; + return -1; extensions_length = bpf_ntohs(extensions_length); offset += 2; // Ensure we don't read beyond the packet if (offset + extensions_length > skb_len || offset + extensions_length > handshake_end) - return false; - - // Parse Extensions (is_client_hello = false) - if (!parse_tls_extensions(skb, offset, extensions_length, false, tup)) { - return false; - } - } + return -1; + + // Parse Extensions + __u64 extensions_end = offset + extensions_length; + #define MAX_EXTENSIONS 16 + __u8 extensions_parsed = 0; + + while (offset + 4 <= extensions_end && extensions_parsed < MAX_EXTENSIONS) { + // Read Extension Type (2 bytes) + __u16 extension_type; + if (bpf_skb_load_bytes(skb, offset, &extension_type, sizeof(extension_type)) < 0) + return -1; + extension_type = bpf_ntohs(extension_type); + offset += 2; - // At this point, we've successfully parsed the ServerHello message - return true; -} + // Read Extension Length (2 bytes) + __u16 extension_length; + if (bpf_skb_load_bytes(skb, offset, &extension_length, sizeof(extension_length)) < 0) + return -1; + extension_length = bpf_ntohs(extension_length); + offset += 2; -// is_tls_handshake checks if the given TLS record is a TLS handshake message. -static __always_inline bool is_tls_handshake(tls_record_header_t *hdr, struct __sk_buff *skb, __u64 offset, conn_tuple_t *tup) { - // Read handshake type - __u8 handshake_type; - if (bpf_skb_load_bytes(skb, offset, &handshake_type, sizeof(handshake_type)) < 0) - return false; + // Ensure we don't read beyond the packet + if (offset + extension_length > skb_len || offset + extension_length > extensions_end) + return -1; + + // Check for supported_versions extension (type 43 or 0x002B) + if (extension_type == 0x002B) { + // Parse supported_versions extension + if (extension_length != 2) + return -1; + + if (offset + 2 > skb_len) + return -1; + + // Read selected version (2 bytes) + __u16 selected_version; + if (bpf_skb_load_bytes(skb, offset, &selected_version, sizeof(selected_version)) < 0) + return -1; + selected_version = bpf_ntohs(selected_version); + offset += 2; + + tags->server_tags.version = selected_version; + } else { + // Skip other extensions + offset += extension_length; + } - // Only proceed if it's a ClientHello or ServerHello - if (handshake_type == TLS_HANDSHAKE_CLIENT_HELLO) { - log_debug("adamk inspecting ClientHello"); - return parse_client_hello(hdr, skb, offset, tup); - } else if (handshake_type == TLS_HANDSHAKE_SERVER_HELLO) { - log_debug("adamk inspecting ServerHello"); - return parse_server_hello(hdr, skb, offset, tup); - } else { - return false; + extensions_parsed++; + } } -} - -// is_tls checks if the given packet is a TLS packet. -static __always_inline bool is_tls(struct __sk_buff *skb, __u64 nh_off, conn_tuple_t *tup) { - __u32 skb_len = skb->len; - - // Ensure there's enough space for TLS record header - if (nh_off + sizeof(tls_record_header_t) > skb_len) - return false; - - // Read TLS record header - tls_record_header_t tls_hdr; - if (bpf_skb_load_bytes(skb, nh_off, &tls_hdr, sizeof(tls_hdr)) < 0) - return false; - - // Convert fields to host byte order - tls_hdr.version = bpf_ntohs(tls_hdr.version); - tls_hdr.length = bpf_ntohs(tls_hdr.length); - - // Validate version and length - if (!is_valid_tls_version(tls_hdr.version)) - return false; - if (tls_hdr.length > TLS_MAX_PAYLOAD_LENGTH) - return false; - - // Move offset to the start of TLS handshake message - nh_off += sizeof(tls_record_header_t); - // Ensure we don't read beyond the packet - if (nh_off + tls_hdr.length > skb_len) - return false; + return 0; +} - // Handle based on content type - switch (tls_hdr.content_type) { - case TLS_HANDSHAKE: { - bool handshake = is_tls_handshake(&tls_hdr, skb, nh_off, tup); - log_debug("adamk is_tls_handshake: %d", handshake); - return handshake; +// Function to parse the TLS payload and update tls_enhanced_tags_t +static __always_inline int parse_tls_payload(struct __sk_buff *skb, __u64 nh_off, tls_record_header_t *tls_hdr, tls_enhanced_tags_t *tags) { + // At this point, tls_hdr has already been validated and filled by is_tls() + __u64 offset = nh_off + sizeof(tls_record_header_t); + + if (tls_hdr->content_type == TLS_HANDSHAKE) { + __u8 handshake_type; + if (bpf_skb_load_bytes(skb, offset, &handshake_type, sizeof(handshake_type)) < 0) + return -1; + + if (handshake_type == TLS_HANDSHAKE_CLIENT_HELLO) { + log_debug("adamk tls classification: client hello"); + return parse_client_hello(skb, offset, skb->len, tags); + } else if (handshake_type == TLS_HANDSHAKE_SERVER_HELLO) { + log_debug("adamk tls classification: server hello"); + return parse_server_hello(skb, offset, skb->len, tags); + } else { + return -1; } - case TLS_APPLICATION_DATA: - return is_valid_tls_app_data(&tls_hdr, skb_len); - default: - return false; + } else { + return -1; } } -#endif +#endif // __TLS_H diff --git a/pkg/network/ebpf/c/tracer/events.h b/pkg/network/ebpf/c/tracer/events.h index 5e0650faa862fd..778070c28ec8e8 100644 --- a/pkg/network/ebpf/c/tracer/events.h +++ b/pkg/network/ebpf/c/tracer/events.h @@ -23,6 +23,7 @@ static __always_inline void clean_protocol_classification(conn_tuple_t *tup) { conn_tuple.netns = 0; normalize_tuple(&conn_tuple); delete_protocol_stack(&conn_tuple, NULL, FLAG_TCP_CLOSE_DELETION); + // TODO: delete TLS enhanced tags conn_tuple_t *skb_tup_ptr = bpf_map_lookup_elem(&conn_tuple_to_socket_skb_conn_tuple, &conn_tuple); if (skb_tup_ptr == NULL) { diff --git a/pkg/network/tracer/tracer_test.go b/pkg/network/tracer/tracer_test.go index 36dc3afa805fa8..31ae680e0a0164 100644 --- a/pkg/network/tracer/tracer_test.go +++ b/pkg/network/tracer/tracer_test.go @@ -1418,9 +1418,9 @@ func (s *TracerSuite) TestTLSClassification() { if !kprobe.ClassificationSupported(cfg) { t.Skip("TLS classification platform not supported") } - - port, err := testutil.GetFreePort() - require.NoError(t, err) + //testutil.GetFreePort() + port := uint16(443) + //require.NoError(t, err) portAsString := strconv.Itoa(int(port)) tr := setupTracer(t, cfg) @@ -1431,7 +1431,8 @@ func (s *TracerSuite) TestTLSClassification() { validation func(t *testing.T, tr *Tracer) } tests := make([]tlsTest, 0) - for _, scenario := range []uint16{tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13} { + //for _, scenario := range []uint16{tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13} { + for _, scenario := range []uint16{tls.VersionTLS12} { scenario := scenario tests = append(tests, tlsTest{ name: strings.Replace(tls.VersionName(scenario), " ", "-", 1), @@ -1449,9 +1450,11 @@ func (s *TracerSuite) TestTLSClassification() { require.NoError(t, srv.Run(done)) t.Cleanup(func() { close(done) }) tlsConfig := &tls.Config{ - MinVersion: scenario, - MaxVersion: scenario, - InsecureSkipVerify: true, + MinVersion: scenario, + MaxVersion: scenario, + InsecureSkipVerify: true, + SessionTicketsDisabled: true, // Disable session tickets + ClientSessionCache: nil, // Disable session cache } conn, err := net.Dial("tcp", "localhost:"+portAsString) require.NoError(t, err) From 2c5bf65d115cfbff97711775633c0a80fcc0fdc4 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Wed, 23 Oct 2024 10:12:55 -0400 Subject: [PATCH 07/53] fall back to TLS version field if extensions aren't present for client --- .../classification/protocol-classification.h | 3 --- pkg/network/ebpf/c/protocols/tls/tls.h | 16 ++++++++----- pkg/network/ebpf/c/tracer/stats.h | 24 ++++++++++++++++++- 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h index 76cef90250503e..006f384dffb1fb 100644 --- a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h +++ b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h @@ -183,9 +183,6 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct log_debug("adamk\n"); log_debug("adamk tls classification: parse_tls_payload=%d", ret); ret++; - log_debug("adamk tls classification: client version 1=%d", tags->client_tags.offered_versions[0]); - log_debug("adamk tls classification: server version=%d", tags->server_tags.version); - log_debug("adamk tls classification: server cipher=%d", tags->server_tags.cipher_suite); } // The connection is TLS encrypted, thus we cannot further classify the protocol // using the socket filter and can bail out; diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index 41ff9657f70504..f763f9a9cdee34 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -17,6 +17,8 @@ #define TLS_HANDSHAKE 0x16 #define TLS_APPLICATION_DATA 0x17 +#define SUPPORTED_VERSIONS_EXTENSION 0x002B + /* https://www.rfc-editor.org/rfc/rfc5246#page-19 6.2. Record Layer */ #define TLS_MAX_PAYLOAD_LENGTH (1 << 14) @@ -104,7 +106,7 @@ static __always_inline bool is_tls(struct __sk_buff *skb, __u64 nh_off, tls_reco return true; } -// Function to parse ClientHello message +// parse_client_hello reads the ClientHello message from the TLS handshake and populates select tags static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offset, __u32 skb_len, tls_enhanced_tags_t *tags) { // Move offset past handshake type (1 byte) offset += 1; @@ -115,7 +117,7 @@ static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offse return -1; __u32 handshake_length = (handshake_length_bytes[0] << 16) | (handshake_length_bytes[1] << 8) | - (handshake_length_bytes[2]); + handshake_length_bytes[2]; offset += 3; // Ensure we don't read beyond the packet @@ -129,6 +131,10 @@ static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offse client_version = bpf_ntohs(client_version); offset += 2; + // Store client_version in tags (in case supported_versions extension is absent) + tags->client_tags.offered_versions[0] = client_version; + tags->client_tags.num_offered_versions = 1; + // Skip Random (32 bytes) offset += 32; @@ -200,7 +206,7 @@ static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offse return -1; // Check for supported_versions extension (type 43 or 0x002B) - if (extension_type == 0x002B) { + if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { // Parse supported_versions extension if (offset + 1 > skb_len) return -1; @@ -335,7 +341,7 @@ static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offse return -1; // Check for supported_versions extension (type 43 or 0x002B) - if (extension_type == 0x002B) { + if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { // Parse supported_versions extension if (extension_length != 2) return -1; @@ -374,10 +380,8 @@ static __always_inline int parse_tls_payload(struct __sk_buff *skb, __u64 nh_off return -1; if (handshake_type == TLS_HANDSHAKE_CLIENT_HELLO) { - log_debug("adamk tls classification: client hello"); return parse_client_hello(skb, offset, skb->len, tags); } else if (handshake_type == TLS_HANDSHAKE_SERVER_HELLO) { - log_debug("adamk tls classification: server hello"); return parse_server_hello(skb, offset, skb->len, tags); } else { return -1; diff --git a/pkg/network/ebpf/c/tracer/stats.h b/pkg/network/ebpf/c/tracer/stats.h index 03ab116b8972e8..60ca6cd2a225e7 100644 --- a/pkg/network/ebpf/c/tracer/stats.h +++ b/pkg/network/ebpf/c/tracer/stats.h @@ -109,6 +109,18 @@ static __always_inline void update_protocol_classification_information(conn_tupl set_protocol_flag(protocol_stack, FLAG_NPM_ENABLED); mark_protocol_direction(t, &conn_tuple_copy, protocol_stack); merge_protocol_stacks(&stats->protocol_stack, protocol_stack); + // lookup from new map and add info to the connection + tls_enhanced_tags_t *tls_tags = get_tls_enhanced_tags(&conn_tuple_copy); + if (tls_tags) { + // TODO: flush the tags to userspace + log_debug("adamk tls classification: client version 1=%d", tls_tags->client_tags.offered_versions[0]); + log_debug("adamk tls classification: client version 2=%d", tls_tags->client_tags.offered_versions[1]); + log_debug("adamk tls classification: client version 3=%d", tls_tags->client_tags.offered_versions[2]); + log_debug("adamk tls classification: server version=%d", tls_tags->server_tags.version); + log_debug("adamk tls classification: server cipher=%d", tls_tags->server_tags.cipher_suite); + } else { + log_debug("adamk tls classification: no tags found"); + } conn_tuple_t *cached_skb_conn_tup_ptr = bpf_map_lookup_elem(&conn_tuple_to_socket_skb_conn_tuple, &conn_tuple_copy); if (!cached_skb_conn_tup_ptr) { @@ -121,7 +133,17 @@ static __always_inline void update_protocol_classification_information(conn_tupl mark_protocol_direction(t, &conn_tuple_copy, protocol_stack); merge_protocol_stacks(&stats->protocol_stack, protocol_stack); // lookup from new map and add info to the connection - // tls_enhanced_tags_t *tls_tags = get_tls_enhanced_tags(&conn_tuple_copy); + tls_tags = get_tls_enhanced_tags(&conn_tuple_copy); + if (tls_tags) { + // TODO: flush the tags to userspace + log_debug("adamk tls classification 2: client version 1=%d", tls_tags->client_tags.offered_versions[0]); + log_debug("adamk tls classification 2: client version 2=%d", tls_tags->client_tags.offered_versions[1]); + log_debug("adamk tls classification 2: client version 3=%d", tls_tags->client_tags.offered_versions[2]); + log_debug("adamk tls classification 2: server version=%d", tls_tags->server_tags.version); + log_debug("adamk tls classification 2: server cipher=%d", tls_tags->server_tags.cipher_suite); + } else { + log_debug("adamk tls classification 2: no tags found"); + } } static __always_inline void determine_connection_direction(conn_tuple_t *t, conn_stats_ts_t *conn_stats) { From d80f6724ddd81253b19ac56aec9e81a0e95e21fe Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Mon, 11 Nov 2024 16:46:45 -0500 Subject: [PATCH 08/53] userspace changes to add tags to proto conn --- .../classification/protocol-classification.h | 2 +- .../classification/shared-tracer-maps.h | 10 +- pkg/network/ebpf/c/protocols/tls/tls.h | 62 ++++--- pkg/network/ebpf/c/tracer/events.h | 3 +- pkg/network/ebpf/c/tracer/stats.h | 20 +-- pkg/network/ebpf/c/tracer/tracer.h | 8 + pkg/network/ebpf/kprobe_types.go | 1 + pkg/network/ebpf/kprobe_types_linux.go | 8 +- pkg/network/encoding/encoding_test.go | 3 + pkg/network/encoding/marshal/format.go | 4 +- pkg/network/event_common.go | 2 + pkg/network/protocols/tls/types.go | 106 ++++++++++++ pkg/network/protocols/tls/types_test.go | 162 ++++++++++++++++++ pkg/network/tags_linux.go | 2 +- pkg/network/tracer/connection/ebpf_tracer.go | 2 + pkg/network/tracer/tracer_test.go | 4 +- 16 files changed, 348 insertions(+), 51 deletions(-) create mode 100644 pkg/network/protocols/tls/types.go create mode 100644 pkg/network/protocols/tls/types_test.go diff --git a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h index 006f384dffb1fb..f01dd3e9e72ad9 100644 --- a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h +++ b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h @@ -176,7 +176,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct update_protocol_information(usm_ctx, protocol_stack, PROTOCOL_TLS); // Parse TLS payload - tls_enhanced_tags_t *tags = get_or_create_tls_enhanced_tags(&skb_tup); + tls_info_t *tags = get_or_create_tls_enhanced_tags(&skb_tup); if (tags) { // Parse the TLS payload and update the tags int ret = parse_tls_payload(skb, skb_info.data_off, &tls_hdr, tags); diff --git a/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h b/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h index cb9d94a5e25c8a..61546eb38e401c 100644 --- a/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h +++ b/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h @@ -10,7 +10,7 @@ // classification procedures on the same connection BPF_HASH_MAP(connection_protocol, conn_tuple_t, protocol_stack_wrapper_t, 0) -BPF_HASH_MAP(tls_enhanced_tags, conn_tuple_t, tls_enhanced_tags_t, 0) +BPF_HASH_MAP(tls_enhanced_tags, conn_tuple_t, tls_info_t, 0) static __always_inline protocol_stack_t* __get_protocol_stack(conn_tuple_t* tuple) { protocol_stack_wrapper_t *wrapper = bpf_map_lookup_elem(&connection_protocol, tuple); @@ -20,18 +20,18 @@ static __always_inline protocol_stack_t* __get_protocol_stack(conn_tuple_t* tupl return &wrapper->stack; } -static __always_inline tls_enhanced_tags_t* get_tls_enhanced_tags(conn_tuple_t* tuple) { +static __always_inline tls_info_t* get_tls_enhanced_tags(conn_tuple_t* tuple) { return bpf_map_lookup_elem(&tls_enhanced_tags, tuple); } -static __always_inline tls_enhanced_tags_t* get_or_create_tls_enhanced_tags(conn_tuple_t *tuple) { +static __always_inline tls_info_t* get_or_create_tls_enhanced_tags(conn_tuple_t *tuple) { conn_tuple_t normalized_tup = *tuple; normalize_tuple(&normalized_tup); - tls_enhanced_tags_t *tags = bpf_map_lookup_elem(&tls_enhanced_tags, &normalized_tup); + tls_info_t *tags = bpf_map_lookup_elem(&tls_enhanced_tags, &normalized_tup); if (!tags) { // Initialize a new entry - tls_enhanced_tags_t empty_tags = {0}; + tls_info_t empty_tags = {0}; bpf_map_update_elem(&tls_enhanced_tags, &normalized_tup, &empty_tags, BPF_NOEXIST); // Lookup again diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index f763f9a9cdee34..07ea63b2c23e80 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -3,7 +3,7 @@ #include "ktypes.h" #include "bpf_builtins.h" -#include "ip.h" +#include "tracer/tracer.h" #define ETH_HLEN 14 // Ethernet header length @@ -30,22 +30,6 @@ typedef struct { __u16 length; } __attribute__((packed)) tls_record_header_t; -// TLS enhanced tags structures -typedef struct { - __u16 offered_versions[6]; - __u8 num_offered_versions; -} tls_client_tags_t; - -typedef struct { - __u16 version; - __u16 cipher_suite; -} tls_server_tags_t; - -typedef struct { - tls_client_tags_t client_tags; - tls_server_tags_t server_tags; -} tls_enhanced_tags_t; - #define TLS_HANDSHAKE_CLIENT_HELLO 0x01 #define TLS_HANDSHAKE_SERVER_HELLO 0x02 @@ -64,6 +48,31 @@ static __always_inline bool is_valid_tls_version(__u16 version) { } } +static __always_inline void set_tls_offered_version(tls_info_t *tls_info, __u16 version) { + switch (version) { + case TLS_VERSION10: + tls_info->offered_versions |= 0x01; // Bit 0 + break; + case TLS_VERSION11: + tls_info->offered_versions |= 0x02; // Bit 1 + break; + case TLS_VERSION12: + tls_info->offered_versions |= 0x04; // Bit 2 + break; + case TLS_VERSION13: + tls_info->offered_versions |= 0x08; // Bit 3 + break; + case SSL_VERSION20: + tls_info->offered_versions |= 0x10; // Bit 4 + break; + case SSL_VERSION30: + tls_info->offered_versions |= 0x20; // Bit 5 + break; + default: + break; + } +} + // Helper function to read and validate the TLS record header static __always_inline bool read_tls_record_header(struct __sk_buff *skb, __u64 nh_off, tls_record_header_t *tls_hdr) { __u32 skb_len = skb->len; @@ -107,7 +116,7 @@ static __always_inline bool is_tls(struct __sk_buff *skb, __u64 nh_off, tls_reco } // parse_client_hello reads the ClientHello message from the TLS handshake and populates select tags -static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offset, __u32 skb_len, tls_enhanced_tags_t *tags) { +static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offset, __u32 skb_len, tls_info_t *tags) { // Move offset past handshake type (1 byte) offset += 1; @@ -132,8 +141,7 @@ static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offse offset += 2; // Store client_version in tags (in case supported_versions extension is absent) - tags->client_tags.offered_versions[0] = client_version; - tags->client_tags.num_offered_versions = 1; + set_tls_offered_version(tags, client_version); // Skip Random (32 bytes) offset += 32; @@ -230,10 +238,8 @@ static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offse return -1; sv_version = bpf_ntohs(sv_version); offset += 2; - - tags->client_tags.offered_versions[num_versions] = sv_version; + set_tls_offered_version(tags, sv_version); } - tags->client_tags.num_offered_versions = num_versions; } else { // Skip other extensions offset += extension_length; @@ -246,7 +252,7 @@ static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offse } // Function to parse ServerHello message -static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offset, __u32 skb_len, tls_enhanced_tags_t *tags) { +static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offset, __u32 skb_len, tls_info_t *tags) { // Move offset past handshake type (1 byte) offset += 1; @@ -273,7 +279,7 @@ static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offse // Set the version here and try to get the "real" version from the extensions // Note: In TLS 1.3, the server_version field is set to 0x0303 (TLS 1.2) // The actual version is embedded in the supported_versions extension - tags->server_tags.version = server_version; + tags->chosen_version = server_version; offset += 2; // Skip Random (32 bytes) @@ -299,7 +305,7 @@ static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offse offset += 1; // Store parsed data into tags - tags->server_tags.cipher_suite = cipher_suite; + tags->cipher_suite = cipher_suite; // Check if there are extensions if (offset < handshake_end) { @@ -356,7 +362,7 @@ static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offse selected_version = bpf_ntohs(selected_version); offset += 2; - tags->server_tags.version = selected_version; + tags->chosen_version = selected_version; } else { // Skip other extensions offset += extension_length; @@ -370,7 +376,7 @@ static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offse } // Function to parse the TLS payload and update tls_enhanced_tags_t -static __always_inline int parse_tls_payload(struct __sk_buff *skb, __u64 nh_off, tls_record_header_t *tls_hdr, tls_enhanced_tags_t *tags) { +static __always_inline int parse_tls_payload(struct __sk_buff *skb, __u64 nh_off, tls_record_header_t *tls_hdr, tls_info_t *tags) { // At this point, tls_hdr has already been validated and filled by is_tls() __u64 offset = nh_off + sizeof(tls_record_header_t); diff --git a/pkg/network/ebpf/c/tracer/events.h b/pkg/network/ebpf/c/tracer/events.h index 778070c28ec8e8..0d8b610c116815 100644 --- a/pkg/network/ebpf/c/tracer/events.h +++ b/pkg/network/ebpf/c/tracer/events.h @@ -23,7 +23,7 @@ static __always_inline void clean_protocol_classification(conn_tuple_t *tup) { conn_tuple.netns = 0; normalize_tuple(&conn_tuple); delete_protocol_stack(&conn_tuple, NULL, FLAG_TCP_CLOSE_DELETION); - // TODO: delete TLS enhanced tags + bpf_map_delete_elem(&tls_enhanced_tags, &conn_tuple); conn_tuple_t *skb_tup_ptr = bpf_map_lookup_elem(&conn_tuple_to_socket_skb_conn_tuple, &conn_tuple); if (skb_tup_ptr == NULL) { @@ -32,6 +32,7 @@ static __always_inline void clean_protocol_classification(conn_tuple_t *tup) { conn_tuple_t skb_tup = *skb_tup_ptr; delete_protocol_stack(&skb_tup, NULL, FLAG_TCP_CLOSE_DELETION); + bpf_map_delete_elem(&tls_enhanced_tags, &conn_tuple); bpf_map_delete_elem(&conn_tuple_to_socket_skb_conn_tuple, &conn_tuple); } diff --git a/pkg/network/ebpf/c/tracer/stats.h b/pkg/network/ebpf/c/tracer/stats.h index 60ca6cd2a225e7..574073ac1da63e 100644 --- a/pkg/network/ebpf/c/tracer/stats.h +++ b/pkg/network/ebpf/c/tracer/stats.h @@ -110,14 +110,13 @@ static __always_inline void update_protocol_classification_information(conn_tupl mark_protocol_direction(t, &conn_tuple_copy, protocol_stack); merge_protocol_stacks(&stats->protocol_stack, protocol_stack); // lookup from new map and add info to the connection - tls_enhanced_tags_t *tls_tags = get_tls_enhanced_tags(&conn_tuple_copy); + tls_info_t *tls_tags = get_tls_enhanced_tags(&conn_tuple_copy); if (tls_tags) { // TODO: flush the tags to userspace - log_debug("adamk tls classification: client version 1=%d", tls_tags->client_tags.offered_versions[0]); - log_debug("adamk tls classification: client version 2=%d", tls_tags->client_tags.offered_versions[1]); - log_debug("adamk tls classification: client version 3=%d", tls_tags->client_tags.offered_versions[2]); - log_debug("adamk tls classification: server version=%d", tls_tags->server_tags.version); - log_debug("adamk tls classification: server cipher=%d", tls_tags->server_tags.cipher_suite); + log_debug("adamk tls classification: client version 1=%d", tls_tags->offered_versions); + log_debug("adamk tls classification: server version=%d", tls_tags->chosen_version); + log_debug("adamk tls classification: server cipher=%d", tls_tags->cipher_suite); + stats->tls_tags = *tls_tags; } else { log_debug("adamk tls classification: no tags found"); } @@ -136,11 +135,10 @@ static __always_inline void update_protocol_classification_information(conn_tupl tls_tags = get_tls_enhanced_tags(&conn_tuple_copy); if (tls_tags) { // TODO: flush the tags to userspace - log_debug("adamk tls classification 2: client version 1=%d", tls_tags->client_tags.offered_versions[0]); - log_debug("adamk tls classification 2: client version 2=%d", tls_tags->client_tags.offered_versions[1]); - log_debug("adamk tls classification 2: client version 3=%d", tls_tags->client_tags.offered_versions[2]); - log_debug("adamk tls classification 2: server version=%d", tls_tags->server_tags.version); - log_debug("adamk tls classification 2: server cipher=%d", tls_tags->server_tags.cipher_suite); + log_debug("adamk tls classification: client version 1=%d", tls_tags->offered_versions); + log_debug("adamk tls classification: server version=%d", tls_tags->chosen_version); + log_debug("adamk tls classification: server cipher=%d", tls_tags->cipher_suite); + stats->tls_tags = *tls_tags; } else { log_debug("adamk tls classification 2: no tags found"); } diff --git a/pkg/network/ebpf/c/tracer/tracer.h b/pkg/network/ebpf/c/tracer/tracer.h index f99301993b539c..4c8088c2168c25 100644 --- a/pkg/network/ebpf/c/tracer/tracer.h +++ b/pkg/network/ebpf/c/tracer/tracer.h @@ -29,6 +29,13 @@ typedef enum { #define CONN_DIRECTION_MASK 0b11 +typedef struct { + __u16 chosen_version; // 2 bytes + __u16 cipher_suite; // 2 bytes + __u8 offered_versions; // 1 byte (6 bits used) + __u8 reserved; // 1 byte (for alignment or future use) +} tls_info_t; + typedef struct { __u64 sent_bytes; __u64 recv_bytes; @@ -54,6 +61,7 @@ typedef struct { protocol_stack_t protocol_stack; __u8 flags; __u8 direction; + tls_info_t tls_tags; } conn_stats_ts_t; // Connection flags diff --git a/pkg/network/ebpf/kprobe_types.go b/pkg/network/ebpf/kprobe_types.go index 6745cdc0b7fee0..3347ad604d7b78 100644 --- a/pkg/network/ebpf/kprobe_types.go +++ b/pkg/network/ebpf/kprobe_types.go @@ -31,6 +31,7 @@ type UDPRecvSock C.udp_recv_sock_t type BindSyscallArgs C.bind_syscall_args_t type ProtocolStack C.protocol_stack_t type ProtocolStackWrapper C.protocol_stack_wrapper_t +type TLSTags C.tls_info_t // udp_recv_sock_t have *sock and *msghdr struct members, we make them opaque here type _Ctype_struct_sock uint64 diff --git a/pkg/network/ebpf/kprobe_types_linux.go b/pkg/network/ebpf/kprobe_types_linux.go index 13d63751f2adc5..3d21324fdaeddd 100644 --- a/pkg/network/ebpf/kprobe_types_linux.go +++ b/pkg/network/ebpf/kprobe_types_linux.go @@ -31,7 +31,7 @@ type ConnStats struct { Protocol_stack ProtocolStack Flags uint8 Direction uint8 - Pad_cgo_0 [6]byte + Tls_tags TLSTags } type Conn struct { Tup ConnTuple @@ -108,6 +108,12 @@ type ProtocolStackWrapper struct { Stack ProtocolStack Updated uint64 } +type TLSTags struct { + Chosen_version uint16 + Cipher_suite uint16 + Offered_versions uint8 + Reserved uint8 +} type _Ctype_struct_sock uint64 type _Ctype_struct_msghdr uint64 diff --git a/pkg/network/encoding/encoding_test.go b/pkg/network/encoding/encoding_test.go index 04b033d8d52699..e37b9a36b6585f 100644 --- a/pkg/network/encoding/encoding_test.go +++ b/pkg/network/encoding/encoding_test.go @@ -23,6 +23,7 @@ import ( pkgconfigsetup "github.com/DataDog/datadog-agent/pkg/config/setup" "github.com/DataDog/datadog-agent/pkg/network" "github.com/DataDog/datadog-agent/pkg/network/dns" + "github.com/DataDog/datadog-agent/pkg/network/ebpf" "github.com/DataDog/datadog-agent/pkg/network/encoding/marshal" "github.com/DataDog/datadog-agent/pkg/network/encoding/unmarshal" "github.com/DataDog/datadog-agent/pkg/network/protocols" @@ -229,6 +230,7 @@ func TestSerialization(t *testing.T) { }, }, ProtocolStack: protocols.Stack{Application: protocols.HTTP}, + TLSTags: ebpf.TLSTags{Chosen_version: 0, Cipher_suite: 1, Offered_versions: 0}, }, {ConnectionTuple: network.ConnectionTuple{ Source: util.AddressFromString("10.1.1.1"), @@ -241,6 +243,7 @@ func TestSerialization(t *testing.T) { Direction: network.LOCAL, StaticTags: tagOpenSSL | tagTLS, ProtocolStack: protocols.Stack{Application: protocols.HTTP2}, + TLSTags: ebpf.TLSTags{Chosen_version: 0, Cipher_suite: 1, Offered_versions: 0}, DNSStats: map[dns.Hostname]map[dns.QueryType]dns.Stats{ dns.ToHostname("foo.com"): { dns.TypeA: { diff --git a/pkg/network/encoding/marshal/format.go b/pkg/network/encoding/marshal/format.go index 4628a0203e390e..9fc457ac85cc71 100644 --- a/pkg/network/encoding/marshal/format.go +++ b/pkg/network/encoding/marshal/format.go @@ -13,6 +13,7 @@ import ( model "github.com/DataDog/agent-payload/v5/process" "github.com/DataDog/datadog-agent/pkg/network" + "github.com/DataDog/datadog-agent/pkg/network/protocols/tls" "github.com/DataDog/datadog-agent/pkg/process/util" ) @@ -120,9 +121,10 @@ func FormatConnection(builder *model.ConnectionBuilder, conn network.ConnectionS httpStaticTags, httpDynamicTags := httpEncoder.GetHTTPAggregationsAndTags(conn, builder) http2StaticTags, http2DynamicTags := http2Encoder.WriteHTTP2AggregationsAndTags(conn, builder) + tlsDynamicTags := tls.GetTLSDynamicTags(&conn.TLSTags) staticTags := httpStaticTags | http2StaticTags - dynamicTags := mergeDynamicTags(httpDynamicTags, http2DynamicTags) + dynamicTags := mergeDynamicTags(httpDynamicTags, http2DynamicTags, tlsDynamicTags) staticTags |= kafkaEncoder.WriteKafkaAggregations(conn, builder) staticTags |= postgresEncoder.WritePostgresAggregations(conn, builder) diff --git a/pkg/network/event_common.go b/pkg/network/event_common.go index e8dcbfd2e77c3b..ecb14dde26b789 100644 --- a/pkg/network/event_common.go +++ b/pkg/network/event_common.go @@ -18,6 +18,7 @@ import ( "go4.org/intern" "github.com/DataDog/datadog-agent/pkg/network/dns" + "github.com/DataDog/datadog-agent/pkg/network/ebpf" "github.com/DataDog/datadog-agent/pkg/network/protocols" "github.com/DataDog/datadog-agent/pkg/network/protocols/http" "github.com/DataDog/datadog-agent/pkg/network/protocols/kafka" @@ -283,6 +284,7 @@ type ConnectionStats struct { RTTVar uint32 StaticTags uint64 ProtocolStack protocols.Stack + TLSTags ebpf.TLSTags // keep these fields last because they are 1 byte each and otherwise inflate the struct size due to alignment Direction ConnectionDirection diff --git a/pkg/network/protocols/tls/types.go b/pkg/network/protocols/tls/types.go new file mode 100644 index 00000000000000..c41211cbae36a6 --- /dev/null +++ b/pkg/network/protocols/tls/types.go @@ -0,0 +1,106 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2024-present Datadog, Inc. + +package tls + +import ( + "fmt" + + "github.com/DataDog/datadog-agent/pkg/network/ebpf" +) + +// TLS and SSL version constants +const ( + SSLVersion20 uint16 = 0x0200 + SSLVersion30 uint16 = 0x0300 + TLSVersion10 uint16 = 0x0301 + TLSVersion11 uint16 = 0x0302 + TLSVersion12 uint16 = 0x0303 + TLSVersion13 uint16 = 0x0304 +) + +// Centralized mapping of version constants to their string representations +var tlsVersionNames = map[uint16]string{ + SSLVersion20: "SSL 2.0", + SSLVersion30: "SSL 3.0", + TLSVersion10: "TLS 1.0", + TLSVersion11: "TLS 1.1", + TLSVersion12: "TLS 1.2", + TLSVersion13: "TLS 1.3", +} + +// Bitmask constants for Offered_versions +const ( + OfferedTLSVersion10 uint8 = 0x01 // Bit 0 + OfferedTLSVersion11 uint8 = 0x02 // Bit 1 + OfferedTLSVersion12 uint8 = 0x04 // Bit 2 + OfferedTLSVersion13 uint8 = 0x08 // Bit 3 + OfferedSSLVersion20 uint8 = 0x10 // Bit 4 + OfferedSSLVersion30 uint8 = 0x20 // Bit 5 +) + +// Mapping of offered version bitmasks to version constants +var offeredVersionBitmask = []struct { + bitMask uint8 + version uint16 +}{ + {OfferedSSLVersion20, SSLVersion20}, + {OfferedSSLVersion30, SSLVersion30}, + {OfferedTLSVersion10, TLSVersion10}, + {OfferedTLSVersion11, TLSVersion11}, + {OfferedTLSVersion12, TLSVersion12}, + {OfferedTLSVersion13, TLSVersion13}, +} + +// Constants for tag keys +const ( + tagTLSVersion = "tls.version:" + tagTLSCipherSuiteID = "tls.cipher_suite_id:" + tagTLSClientVersion = "tls.client_version:" +) + +// FormatTLSVersion converts a version uint16 to its string representation +func FormatTLSVersion(version uint16) string { + if name, ok := tlsVersionNames[version]; ok { + return name + } + return "" +} + +// parseOfferedVersions parses the Offered_versions bitmask into a slice of version strings +func parseOfferedVersions(offeredVersions uint8) []string { + var versions []string + for _, ov := range offeredVersionBitmask { + if (offeredVersions & ov.bitMask) != 0 { + if name := tlsVersionNames[ov.version]; name != "" { + versions = append(versions, name) + } + } + } + return versions +} + +// GetTLSDynamicTags generates dynamic tags based on TLS information +func GetTLSDynamicTags(tls *ebpf.TLSTags) map[string]struct{} { + tags := make(map[string]struct{}) + if tls == nil { + return tags + } + + // Server chosen version + if versionName := FormatTLSVersion(tls.Chosen_version); versionName != "" { + tags[tagTLSVersion+versionName] = struct{}{} + } + + // Cipher suite ID as hex string + tags[tagTLSCipherSuiteID+fmt.Sprintf("0x%04X", tls.Cipher_suite)] = struct{}{} + + // Client offered versions + for _, versionName := range parseOfferedVersions(tls.Offered_versions) { + tags[tagTLSClientVersion+versionName] = struct{}{} + } + + return tags +} diff --git a/pkg/network/protocols/tls/types_test.go b/pkg/network/protocols/tls/types_test.go new file mode 100644 index 00000000000000..060d2fe45fcabc --- /dev/null +++ b/pkg/network/protocols/tls/types_test.go @@ -0,0 +1,162 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2024-present Datadog, Inc. + +package tls + +import ( + "fmt" + "reflect" + "testing" + + "github.com/DataDog/datadog-agent/pkg/network/ebpf" +) + +func TestFormatTLSVersion(t *testing.T) { + tests := []struct { + version uint16 + expected string + }{ + {TLSVersion10, "TLS 1.0"}, + {TLSVersion11, "TLS 1.1"}, + {TLSVersion12, "TLS 1.2"}, + {TLSVersion13, "TLS 1.3"}, + {SSLVersion20, "SSL 2.0"}, + {SSLVersion30, "SSL 3.0"}, + {0xFFFF, ""}, // Unknown version + {0x0000, ""}, // Zero value + {0x0305, ""}, // Version just above known versions + {0x01FF, ""}, // Random unknown version + } + + for _, test := range tests { + t.Run(fmt.Sprintf("Version_0x%04X", test.version), func(t *testing.T) { + result := FormatTLSVersion(test.version) + if result != test.expected { + t.Errorf("FormatTLSVersion(0x%04X) = %q; want %q", test.version, result, test.expected) + } + }) + } +} + +func TestParseOfferedVersions(t *testing.T) { + tests := []struct { + offeredVersions uint8 + expected []string + }{ + {0x00, []string{}}, // No versions offered + {OfferedTLSVersion10, []string{"TLS 1.0"}}, + {OfferedTLSVersion11, []string{"TLS 1.1"}}, + {OfferedTLSVersion12, []string{"TLS 1.2"}}, + {OfferedTLSVersion13, []string{"TLS 1.3"}}, + {OfferedSSLVersion20, []string{"SSL 2.0"}}, + {OfferedSSLVersion30, []string{"SSL 3.0"}}, + {OfferedTLSVersion10 | OfferedTLSVersion12, []string{"TLS 1.0", "TLS 1.2"}}, + {OfferedTLSVersion11 | OfferedTLSVersion13 | OfferedSSLVersion30, []string{"TLS 1.1", "TLS 1.3", "SSL 3.0"}}, + {0xFF, []string{"TLS 1.0", "TLS 1.1", "TLS 1.2", "TLS 1.3", "SSL 2.0", "SSL 3.0"}}, // All bits set + {0x40, []string{}}, // Undefined bit set + {0x80, []string{}}, // Undefined bit set + } + + for _, test := range tests { + t.Run(fmt.Sprintf("OfferedVersions_0x%02X", test.offeredVersions), func(t *testing.T) { + result := parseOfferedVersions(test.offeredVersions) + if !reflect.DeepEqual(result, test.expected) { + t.Errorf("parseOfferedVersions(0x%02X) = %v; want %v", test.offeredVersions, result, test.expected) + } + }) + } +} + +func TestGetTLSDynamicTags(t *testing.T) { + tests := []struct { + name string + tlsTags *ebpf.TLSTags + expected map[string]struct{} + }{ + { + name: "Nil_TLSTags", + tlsTags: nil, + expected: map[string]struct{}{}, + }, + { + name: "All_Fields_Populated", + tlsTags: &ebpf.TLSTags{ + Chosen_version: TLSVersion12, + Cipher_suite: 0x009C, + Offered_versions: OfferedTLSVersion11 | OfferedTLSVersion12, + }, + expected: map[string]struct{}{ + "tls.version:TLS 1.2": {}, + "tls.cipher_suite_id:0x009C": {}, + "tls.client_version:TLS 1.1": {}, + "tls.client_version:TLS 1.2": {}, + }, + }, + { + name: "Unknown_Chosen_Version", + tlsTags: &ebpf.TLSTags{ + Chosen_version: 0xFFFF, // Unknown version + Cipher_suite: 0x00FF, + Offered_versions: OfferedTLSVersion13, + }, + expected: map[string]struct{}{ + "tls.cipher_suite_id:0x00FF": {}, + "tls.client_version:TLS 1.3": {}, + }, + }, + { + name: "No_Offered_Versions", + tlsTags: &ebpf.TLSTags{ + Chosen_version: TLSVersion13, + Cipher_suite: 0x1301, + Offered_versions: 0x00, + }, + expected: map[string]struct{}{ + "tls.version:TLS 1.3": {}, + "tls.cipher_suite_id:0x1301": {}, + }, + }, + { + name: "Zero_Cipher_Suite", + tlsTags: &ebpf.TLSTags{ + Chosen_version: TLSVersion10, + Cipher_suite: 0x0000, + Offered_versions: OfferedTLSVersion10, + }, + expected: map[string]struct{}{ + "tls.version:TLS 1.0": {}, + "tls.cipher_suite_id:0x0000": {}, + "tls.client_version:TLS 1.0": {}, + }, + }, + { + name: "All_Bits_Set_In_Offered_Versions", + tlsTags: &ebpf.TLSTags{ + Chosen_version: TLSVersion12, + Cipher_suite: 0xC02F, + Offered_versions: 0xFF, // All bits set + }, + expected: map[string]struct{}{ + "tls.version:TLS 1.2": {}, + "tls.cipher_suite_id:0xC02F": {}, + "tls.client_version:TLS 1.0": {}, + "tls.client_version:TLS 1.1": {}, + "tls.client_version:TLS 1.2": {}, + "tls.client_version:TLS 1.3": {}, + "tls.client_version:SSL 2.0": {}, + "tls.client_version:SSL 3.0": {}, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := GetTLSDynamicTags(test.tlsTags) + if !reflect.DeepEqual(result, test.expected) { + t.Errorf("GetTLSDynamicTags(%v) = %v; want %v", test.tlsTags, result, test.expected) + } + }) + } +} diff --git a/pkg/network/tags_linux.go b/pkg/network/tags_linux.go index 4f81692cfc31a5..4c2261a2457684 100644 --- a/pkg/network/tags_linux.go +++ b/pkg/network/tags_linux.go @@ -26,7 +26,7 @@ const ( ConnTagNodeJS = http.NodeJS ) -// GetStaticTags return the string list of static tags from network.ConnectionStats.Tags +// GetStaticTags return the string list of static tags from network.ConnectionStats.StaticTags func GetStaticTags(staticTags uint64) (tags []string) { for tag, str := range http.StaticTags { if (staticTags & tag) > 0 { diff --git a/pkg/network/tracer/connection/ebpf_tracer.go b/pkg/network/tracer/connection/ebpf_tracer.go index e529b4abf76c7f..8e66b467c0b25e 100644 --- a/pkg/network/tracer/connection/ebpf_tracer.go +++ b/pkg/network/tracer/connection/ebpf_tracer.go @@ -763,6 +763,8 @@ func populateConnStats(stats *network.ConnectionStats, t *netebpf.ConnTuple, s * Encryption: protocols.Encryption(s.Protocol_stack.Encryption), } + stats.TLSTags = s.Tls_tags + if t.Type() == netebpf.TCP { stats.Type = network.TCP } else { diff --git a/pkg/network/tracer/tracer_test.go b/pkg/network/tracer/tracer_test.go index 2496cc35438bfe..dae63cc7ed8970 100644 --- a/pkg/network/tracer/tracer_test.go +++ b/pkg/network/tracer/tracer_test.go @@ -1470,8 +1470,8 @@ func (s *TracerSuite) TestTLSClassification() { validation func(t *testing.T, tr *Tracer) } tests := make([]tlsTest, 0) - //for _, scenario := range []uint16{tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13} { - for _, scenario := range []uint16{tls.VersionTLS12} { + for _, scenario := range []uint16{tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13} { + //for _, scenario := range []uint16{tls.VersionTLS12} { scenario := scenario tests = append(tests, tlsTest{ name: strings.Replace(tls.VersionName(scenario), " ", "-", 1), From 152ae540606cd6380fec6e51736910181e89b4ff Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Tue, 12 Nov 2024 16:41:22 -0500 Subject: [PATCH 09/53] clean up and test for presence of version tag --- pkg/network/ebpf/c/protocols/tls/tls.h | 12 ++-- pkg/network/encoding/encoding_test.go | 6 +- pkg/network/event_common.go | 4 +- pkg/network/protocols/tls/types.go | 71 +++++++++++++------- pkg/network/protocols/tls/types_test.go | 56 +++++++-------- pkg/network/state.go | 2 + pkg/network/tracer/connection/ebpf_tracer.go | 7 +- pkg/network/tracer/tracer_test.go | 15 +++-- 8 files changed, 102 insertions(+), 71 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index 07ea63b2c23e80..c68275cc3951a3 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -50,22 +50,22 @@ static __always_inline bool is_valid_tls_version(__u16 version) { static __always_inline void set_tls_offered_version(tls_info_t *tls_info, __u16 version) { switch (version) { - case TLS_VERSION10: + case SSL_VERSION20: tls_info->offered_versions |= 0x01; // Bit 0 break; - case TLS_VERSION11: + case SSL_VERSION30: tls_info->offered_versions |= 0x02; // Bit 1 break; - case TLS_VERSION12: + case TLS_VERSION10: tls_info->offered_versions |= 0x04; // Bit 2 break; - case TLS_VERSION13: + case TLS_VERSION11: tls_info->offered_versions |= 0x08; // Bit 3 break; - case SSL_VERSION20: + case TLS_VERSION12: tls_info->offered_versions |= 0x10; // Bit 4 break; - case SSL_VERSION30: + case TLS_VERSION13: tls_info->offered_versions |= 0x20; // Bit 5 break; default: diff --git a/pkg/network/encoding/encoding_test.go b/pkg/network/encoding/encoding_test.go index e37b9a36b6585f..e12264c571963a 100644 --- a/pkg/network/encoding/encoding_test.go +++ b/pkg/network/encoding/encoding_test.go @@ -23,13 +23,13 @@ import ( pkgconfigsetup "github.com/DataDog/datadog-agent/pkg/config/setup" "github.com/DataDog/datadog-agent/pkg/network" "github.com/DataDog/datadog-agent/pkg/network/dns" - "github.com/DataDog/datadog-agent/pkg/network/ebpf" "github.com/DataDog/datadog-agent/pkg/network/encoding/marshal" "github.com/DataDog/datadog-agent/pkg/network/encoding/unmarshal" "github.com/DataDog/datadog-agent/pkg/network/protocols" "github.com/DataDog/datadog-agent/pkg/network/protocols/http" "github.com/DataDog/datadog-agent/pkg/network/protocols/kafka" "github.com/DataDog/datadog-agent/pkg/network/protocols/telemetry" + "github.com/DataDog/datadog-agent/pkg/network/protocols/tls" "github.com/DataDog/datadog-agent/pkg/process/util" ) @@ -230,7 +230,7 @@ func TestSerialization(t *testing.T) { }, }, ProtocolStack: protocols.Stack{Application: protocols.HTTP}, - TLSTags: ebpf.TLSTags{Chosen_version: 0, Cipher_suite: 1, Offered_versions: 0}, + TLSTags: tls.Tags{ChosenVersion: 0, CipherSuite: 0, OfferedVersions: 0}, }, {ConnectionTuple: network.ConnectionTuple{ Source: util.AddressFromString("10.1.1.1"), @@ -243,7 +243,7 @@ func TestSerialization(t *testing.T) { Direction: network.LOCAL, StaticTags: tagOpenSSL | tagTLS, ProtocolStack: protocols.Stack{Application: protocols.HTTP2}, - TLSTags: ebpf.TLSTags{Chosen_version: 0, Cipher_suite: 1, Offered_versions: 0}, + TLSTags: tls.Tags{ChosenVersion: 0, CipherSuite: 0, OfferedVersions: 0}, DNSStats: map[dns.Hostname]map[dns.QueryType]dns.Stats{ dns.ToHostname("foo.com"): { dns.TypeA: { diff --git a/pkg/network/event_common.go b/pkg/network/event_common.go index ecb14dde26b789..d618ec9302e662 100644 --- a/pkg/network/event_common.go +++ b/pkg/network/event_common.go @@ -18,12 +18,12 @@ import ( "go4.org/intern" "github.com/DataDog/datadog-agent/pkg/network/dns" - "github.com/DataDog/datadog-agent/pkg/network/ebpf" "github.com/DataDog/datadog-agent/pkg/network/protocols" "github.com/DataDog/datadog-agent/pkg/network/protocols/http" "github.com/DataDog/datadog-agent/pkg/network/protocols/kafka" "github.com/DataDog/datadog-agent/pkg/network/protocols/postgres" "github.com/DataDog/datadog-agent/pkg/network/protocols/redis" + "github.com/DataDog/datadog-agent/pkg/network/protocols/tls" "github.com/DataDog/datadog-agent/pkg/process/util" ) @@ -284,7 +284,7 @@ type ConnectionStats struct { RTTVar uint32 StaticTags uint64 ProtocolStack protocols.Stack - TLSTags ebpf.TLSTags + TLSTags tls.Tags // keep these fields last because they are 1 byte each and otherwise inflate the struct size due to alignment Direction ConnectionDirection diff --git a/pkg/network/protocols/tls/types.go b/pkg/network/protocols/tls/types.go index c41211cbae36a6..ccffa36f7b6b4f 100644 --- a/pkg/network/protocols/tls/types.go +++ b/pkg/network/protocols/tls/types.go @@ -5,11 +5,7 @@ package tls -import ( - "fmt" - - "github.com/DataDog/datadog-agent/pkg/network/ebpf" -) +import "fmt" // TLS and SSL version constants const ( @@ -21,7 +17,17 @@ const ( TLSVersion13 uint16 = 0x0304 ) -// Centralized mapping of version constants to their string representations +// Bitmask constants for Offered_versions +const ( + OfferedSSLVersion20 uint8 = 0x01 // Bit 0 + OfferedSSLVersion30 uint8 = 0x02 // Bit 1 + OfferedTLSVersion10 uint8 = 0x04 // Bit 2 + OfferedTLSVersion11 uint8 = 0x08 // Bit 3 + OfferedTLSVersion12 uint8 = 0x10 // Bit 4 + OfferedTLSVersion13 uint8 = 0x20 // Bit 5 +) + +// mapping of version constants to their string representations var tlsVersionNames = map[uint16]string{ SSLVersion20: "SSL 2.0", SSLVersion30: "SSL 3.0", @@ -31,16 +37,6 @@ var tlsVersionNames = map[uint16]string{ TLSVersion13: "TLS 1.3", } -// Bitmask constants for Offered_versions -const ( - OfferedTLSVersion10 uint8 = 0x01 // Bit 0 - OfferedTLSVersion11 uint8 = 0x02 // Bit 1 - OfferedTLSVersion12 uint8 = 0x04 // Bit 2 - OfferedTLSVersion13 uint8 = 0x08 // Bit 3 - OfferedSSLVersion20 uint8 = 0x10 // Bit 4 - OfferedSSLVersion30 uint8 = 0x20 // Bit 5 -) - // Mapping of offered version bitmasks to version constants var offeredVersionBitmask = []struct { bitMask uint8 @@ -56,11 +52,38 @@ var offeredVersionBitmask = []struct { // Constants for tag keys const ( - tagTLSVersion = "tls.version:" + TagTLSVersion = "tls.version:" tagTLSCipherSuiteID = "tls.cipher_suite_id:" tagTLSClientVersion = "tls.client_version:" ) +type Tags struct { + ChosenVersion uint16 + CipherSuite uint16 + OfferedVersions uint8 +} + +func (t *Tags) MergeWith(that Tags) { + if t.ChosenVersion == 0 { + t.ChosenVersion = that.ChosenVersion + } + if t.CipherSuite == 0 { + t.CipherSuite = that.CipherSuite + } + if t.OfferedVersions == 0 { + t.OfferedVersions = that.OfferedVersions + } + +} + +func (t *Tags) IsEmpty() bool { + return t.ChosenVersion == 0 && t.CipherSuite == 0 && t.OfferedVersions == 0 +} + +func (t *Tags) String() string { + return fmt.Sprintf("ChosenVersion: %d, CipherSuite: %d, OfferedVersions: %d", t.ChosenVersion, t.CipherSuite, t.OfferedVersions) +} + // FormatTLSVersion converts a version uint16 to its string representation func FormatTLSVersion(version uint16) string { if name, ok := tlsVersionNames[version]; ok { @@ -71,7 +94,7 @@ func FormatTLSVersion(version uint16) string { // parseOfferedVersions parses the Offered_versions bitmask into a slice of version strings func parseOfferedVersions(offeredVersions uint8) []string { - var versions []string + versions := []string{} for _, ov := range offeredVersionBitmask { if (offeredVersions & ov.bitMask) != 0 { if name := tlsVersionNames[ov.version]; name != "" { @@ -83,22 +106,24 @@ func parseOfferedVersions(offeredVersions uint8) []string { } // GetTLSDynamicTags generates dynamic tags based on TLS information -func GetTLSDynamicTags(tls *ebpf.TLSTags) map[string]struct{} { +func GetTLSDynamicTags(tls *Tags) map[string]struct{} { tags := make(map[string]struct{}) if tls == nil { return tags } // Server chosen version - if versionName := FormatTLSVersion(tls.Chosen_version); versionName != "" { - tags[tagTLSVersion+versionName] = struct{}{} + if versionName := FormatTLSVersion(tls.ChosenVersion); versionName != "" { + tags[TagTLSVersion+versionName] = struct{}{} } // Cipher suite ID as hex string - tags[tagTLSCipherSuiteID+fmt.Sprintf("0x%04X", tls.Cipher_suite)] = struct{}{} + if tls.CipherSuite != 0 { + tags[tagTLSCipherSuiteID+fmt.Sprintf("0x%04X", tls.CipherSuite)] = struct{}{} + } // Client offered versions - for _, versionName := range parseOfferedVersions(tls.Offered_versions) { + for _, versionName := range parseOfferedVersions(tls.OfferedVersions) { tags[tagTLSClientVersion+versionName] = struct{}{} } diff --git a/pkg/network/protocols/tls/types_test.go b/pkg/network/protocols/tls/types_test.go index 060d2fe45fcabc..5b5f6258ec88fd 100644 --- a/pkg/network/protocols/tls/types_test.go +++ b/pkg/network/protocols/tls/types_test.go @@ -9,8 +9,6 @@ import ( "fmt" "reflect" "testing" - - "github.com/DataDog/datadog-agent/pkg/network/ebpf" ) func TestFormatTLSVersion(t *testing.T) { @@ -18,12 +16,12 @@ func TestFormatTLSVersion(t *testing.T) { version uint16 expected string }{ + {SSLVersion20, "SSL 2.0"}, + {SSLVersion30, "SSL 3.0"}, {TLSVersion10, "TLS 1.0"}, {TLSVersion11, "TLS 1.1"}, {TLSVersion12, "TLS 1.2"}, {TLSVersion13, "TLS 1.3"}, - {SSLVersion20, "SSL 2.0"}, - {SSLVersion30, "SSL 3.0"}, {0xFFFF, ""}, // Unknown version {0x0000, ""}, // Zero value {0x0305, ""}, // Version just above known versions @@ -46,15 +44,15 @@ func TestParseOfferedVersions(t *testing.T) { expected []string }{ {0x00, []string{}}, // No versions offered + {OfferedSSLVersion20, []string{"SSL 2.0"}}, + {OfferedSSLVersion30, []string{"SSL 3.0"}}, {OfferedTLSVersion10, []string{"TLS 1.0"}}, {OfferedTLSVersion11, []string{"TLS 1.1"}}, {OfferedTLSVersion12, []string{"TLS 1.2"}}, {OfferedTLSVersion13, []string{"TLS 1.3"}}, - {OfferedSSLVersion20, []string{"SSL 2.0"}}, - {OfferedSSLVersion30, []string{"SSL 3.0"}}, {OfferedTLSVersion10 | OfferedTLSVersion12, []string{"TLS 1.0", "TLS 1.2"}}, - {OfferedTLSVersion11 | OfferedTLSVersion13 | OfferedSSLVersion30, []string{"TLS 1.1", "TLS 1.3", "SSL 3.0"}}, - {0xFF, []string{"TLS 1.0", "TLS 1.1", "TLS 1.2", "TLS 1.3", "SSL 2.0", "SSL 3.0"}}, // All bits set + {OfferedSSLVersion30 | OfferedTLSVersion11 | OfferedTLSVersion13, []string{"SSL 3.0", "TLS 1.1", "TLS 1.3"}}, + {0xFF, []string{"SSL 2.0", "SSL 3.0", "TLS 1.0", "TLS 1.1", "TLS 1.2", "TLS 1.3"}}, // All bits set {0x40, []string{}}, // Undefined bit set {0x80, []string{}}, // Undefined bit set } @@ -72,7 +70,7 @@ func TestParseOfferedVersions(t *testing.T) { func TestGetTLSDynamicTags(t *testing.T) { tests := []struct { name string - tlsTags *ebpf.TLSTags + tlsTags *Tags expected map[string]struct{} }{ { @@ -82,10 +80,10 @@ func TestGetTLSDynamicTags(t *testing.T) { }, { name: "All_Fields_Populated", - tlsTags: &ebpf.TLSTags{ - Chosen_version: TLSVersion12, - Cipher_suite: 0x009C, - Offered_versions: OfferedTLSVersion11 | OfferedTLSVersion12, + tlsTags: &Tags{ + ChosenVersion: TLSVersion12, + CipherSuite: 0x009C, + OfferedVersions: OfferedTLSVersion11 | OfferedTLSVersion12, }, expected: map[string]struct{}{ "tls.version:TLS 1.2": {}, @@ -96,10 +94,10 @@ func TestGetTLSDynamicTags(t *testing.T) { }, { name: "Unknown_Chosen_Version", - tlsTags: &ebpf.TLSTags{ - Chosen_version: 0xFFFF, // Unknown version - Cipher_suite: 0x00FF, - Offered_versions: OfferedTLSVersion13, + tlsTags: &Tags{ + ChosenVersion: 0xFFFF, // Unknown version + CipherSuite: 0x00FF, + OfferedVersions: OfferedTLSVersion13, }, expected: map[string]struct{}{ "tls.cipher_suite_id:0x00FF": {}, @@ -108,10 +106,10 @@ func TestGetTLSDynamicTags(t *testing.T) { }, { name: "No_Offered_Versions", - tlsTags: &ebpf.TLSTags{ - Chosen_version: TLSVersion13, - Cipher_suite: 0x1301, - Offered_versions: 0x00, + tlsTags: &Tags{ + ChosenVersion: TLSVersion13, + CipherSuite: 0x1301, + OfferedVersions: 0x00, }, expected: map[string]struct{}{ "tls.version:TLS 1.3": {}, @@ -120,23 +118,21 @@ func TestGetTLSDynamicTags(t *testing.T) { }, { name: "Zero_Cipher_Suite", - tlsTags: &ebpf.TLSTags{ - Chosen_version: TLSVersion10, - Cipher_suite: 0x0000, - Offered_versions: OfferedTLSVersion10, + tlsTags: &Tags{ + ChosenVersion: TLSVersion10, + OfferedVersions: OfferedTLSVersion10, }, expected: map[string]struct{}{ "tls.version:TLS 1.0": {}, - "tls.cipher_suite_id:0x0000": {}, "tls.client_version:TLS 1.0": {}, }, }, { name: "All_Bits_Set_In_Offered_Versions", - tlsTags: &ebpf.TLSTags{ - Chosen_version: TLSVersion12, - Cipher_suite: 0xC02F, - Offered_versions: 0xFF, // All bits set + tlsTags: &Tags{ + ChosenVersion: TLSVersion12, + CipherSuite: 0xC02F, + OfferedVersions: 0xFF, // All bits set }, expected: map[string]struct{}{ "tls.version:TLS 1.2": {}, diff --git a/pkg/network/state.go b/pkg/network/state.go index 155e75e78ca075..85bc10c0d1437d 100644 --- a/pkg/network/state.go +++ b/pkg/network/state.go @@ -1421,6 +1421,7 @@ func (ac *aggregateConnection) merge(c *ConnectionStats) { } ac.ProtocolStack.MergeWith(c.ProtocolStack) + ac.TLSTags.MergeWith(c.TLSTags) if ac.DNSStats == nil { ac.DNSStats = c.DNSStats @@ -1484,6 +1485,7 @@ func (ns *networkState) mergeConnectionStats(a, b *ConnectionStats) (collision b } a.ProtocolStack.MergeWith(b.ProtocolStack) + a.TLSTags.MergeWith(b.TLSTags) return false } diff --git a/pkg/network/tracer/connection/ebpf_tracer.go b/pkg/network/tracer/connection/ebpf_tracer.go index 8e66b467c0b25e..833d5a237dcbd3 100644 --- a/pkg/network/tracer/connection/ebpf_tracer.go +++ b/pkg/network/tracer/connection/ebpf_tracer.go @@ -32,6 +32,7 @@ import ( netebpf "github.com/DataDog/datadog-agent/pkg/network/ebpf" "github.com/DataDog/datadog-agent/pkg/network/ebpf/probes" "github.com/DataDog/datadog-agent/pkg/network/protocols" + "github.com/DataDog/datadog-agent/pkg/network/protocols/tls" "github.com/DataDog/datadog-agent/pkg/network/tracer/connection/failure" "github.com/DataDog/datadog-agent/pkg/network/tracer/connection/fentry" "github.com/DataDog/datadog-agent/pkg/network/tracer/connection/kprobe" @@ -763,7 +764,11 @@ func populateConnStats(stats *network.ConnectionStats, t *netebpf.ConnTuple, s * Encryption: protocols.Encryption(s.Protocol_stack.Encryption), } - stats.TLSTags = s.Tls_tags + stats.TLSTags = tls.Tags{ + ChosenVersion: s.Tls_tags.Chosen_version, + CipherSuite: s.Tls_tags.Cipher_suite, + OfferedVersions: s.Tls_tags.Offered_versions, + } if t.Type() == netebpf.TCP { stats.Type = network.TCP diff --git a/pkg/network/tracer/tracer_test.go b/pkg/network/tracer/tracer_test.go index dae63cc7ed8970..4a6e7fc0220270 100644 --- a/pkg/network/tracer/tracer_test.go +++ b/pkg/network/tracer/tracer_test.go @@ -40,6 +40,7 @@ import ( "github.com/DataDog/datadog-agent/pkg/network/config" "github.com/DataDog/datadog-agent/pkg/network/protocols" usmtestutil "github.com/DataDog/datadog-agent/pkg/network/protocols/http/testutil" + ddtls "github.com/DataDog/datadog-agent/pkg/network/protocols/tls" "github.com/DataDog/datadog-agent/pkg/network/tracer/connection/kprobe" "github.com/DataDog/datadog-agent/pkg/network/tracer/testutil" "github.com/DataDog/datadog-agent/pkg/network/tracer/testutil/testdns" @@ -1457,9 +1458,8 @@ func (s *TracerSuite) TestTLSClassification() { if !kprobe.ClassificationSupported(cfg) { t.Skip("TLS classification platform not supported") } - //testutil.GetFreePort() - port := uint16(443) - //require.NoError(t, err) + port, err := testutil.GetFreePort() + require.NoError(t, err) portAsString := strconv.Itoa(int(port)) tr := setupTracer(t, cfg) @@ -1471,7 +1471,6 @@ func (s *TracerSuite) TestTLSClassification() { } tests := make([]tlsTest, 0) for _, scenario := range []uint16{tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13} { - //for _, scenario := range []uint16{tls.VersionTLS12} { scenario := scenario tests = append(tests, tlsTest{ name: strings.Replace(tls.VersionName(scenario), " ", "-", 1), @@ -1506,11 +1505,15 @@ func (s *TracerSuite) TestTLSClassification() { require.NoError(t, tlsConn.Handshake()) }, validation: func(t *testing.T, tr *Tracer) { - // Iterate through active connections until we find connection created above require.Eventuallyf(t, func() bool { payload := getConnections(t, tr) for _, c := range payload.Conns { - if c.DPort == port && c.ProtocolStack.Contains(protocols.TLS) { + if c.DPort == port && c.ProtocolStack.Contains(protocols.TLS) && !c.TLSTags.IsEmpty() { + expectedTagKey := ddtls.TagTLSVersion + tls.VersionName(scenario) + tlsTags := ddtls.GetTLSDynamicTags(&c.TLSTags) + if _, ok := tlsTags[expectedTagKey]; !ok { + return false + } return true } } From 9e3107480539072c944367273a23c8454e05960c Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Tue, 12 Nov 2024 16:43:40 -0500 Subject: [PATCH 10/53] add comments --- pkg/network/protocols/tls/types.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pkg/network/protocols/tls/types.go b/pkg/network/protocols/tls/types.go index ccffa36f7b6b4f..ca57fa10d7bc73 100644 --- a/pkg/network/protocols/tls/types.go +++ b/pkg/network/protocols/tls/types.go @@ -57,12 +57,15 @@ const ( tagTLSClientVersion = "tls.client_version:" ) +// Tags holds the TLS tags. It is used to store the TLS version, cipher suite and offered versions. +// We can't use the struct from eBPF as the definition is shared with windows. type Tags struct { ChosenVersion uint16 CipherSuite uint16 OfferedVersions uint8 } +// MergeWith merges the tags from another Tags struct into this one func (t *Tags) MergeWith(that Tags) { if t.ChosenVersion == 0 { t.ChosenVersion = that.ChosenVersion @@ -76,10 +79,12 @@ func (t *Tags) MergeWith(that Tags) { } +// IsEmpty returns true if all fields are zero func (t *Tags) IsEmpty() bool { return t.ChosenVersion == 0 && t.CipherSuite == 0 && t.OfferedVersions == 0 } +// String returns a string representation of the Tags struct func (t *Tags) String() string { return fmt.Sprintf("ChosenVersion: %d, CipherSuite: %d, OfferedVersions: %d", t.ChosenVersion, t.CipherSuite, t.OfferedVersions) } From ad4d31ca01c3146f5e6301360a661d48c2a9f09b Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Tue, 12 Nov 2024 16:45:39 -0500 Subject: [PATCH 11/53] package comment --- pkg/network/protocols/tls/types.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/network/protocols/tls/types.go b/pkg/network/protocols/tls/types.go index ca57fa10d7bc73..1ea3fe85d46fba 100644 --- a/pkg/network/protocols/tls/types.go +++ b/pkg/network/protocols/tls/types.go @@ -3,6 +3,7 @@ // This product includes software developed at Datadog (https://www.datadoghq.com/). // Copyright 2024-present Datadog, Inc. +// Package tls contains definitions and methods related to tags parsed from the TLS handshake package tls import "fmt" From 0cee274193d7846d32a0f5fa56179a915eb33621 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Tue, 12 Nov 2024 17:33:29 -0500 Subject: [PATCH 12/53] cleanup --- .../classification/protocol-classification.h | 6 +--- .../classification/shared-tracer-maps.h | 3 -- .../ebpf/c/protocols/tls/tls-helpers.h | 34 +++++++++++++++++++ pkg/network/ebpf/c/protocols/tls/tls.h | 13 ++++--- pkg/network/ebpf/c/tracer/stats.h | 13 +------ 5 files changed, 42 insertions(+), 27 deletions(-) create mode 100644 pkg/network/ebpf/c/protocols/tls/tls-helpers.h diff --git a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h index f01dd3e9e72ad9..9d00aa965d5dab 100644 --- a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h +++ b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h @@ -178,11 +178,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct // Parse TLS payload tls_info_t *tags = get_or_create_tls_enhanced_tags(&skb_tup); if (tags) { - // Parse the TLS payload and update the tags - int ret = parse_tls_payload(skb, skb_info.data_off, &tls_hdr, tags); - log_debug("adamk\n"); - log_debug("adamk tls classification: parse_tls_payload=%d", ret); - ret++; + parse_tls_payload(skb, skb_info.data_off, &tls_hdr, tags); } // The connection is TLS encrypted, thus we cannot further classify the protocol // using the socket filter and can bail out; diff --git a/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h b/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h index 61546eb38e401c..bcd132dd9d4918 100644 --- a/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h +++ b/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h @@ -30,11 +30,8 @@ static __always_inline tls_info_t* get_or_create_tls_enhanced_tags(conn_tuple_t tls_info_t *tags = bpf_map_lookup_elem(&tls_enhanced_tags, &normalized_tup); if (!tags) { - // Initialize a new entry tls_info_t empty_tags = {0}; bpf_map_update_elem(&tls_enhanced_tags, &normalized_tup, &empty_tags, BPF_NOEXIST); - - // Lookup again tags = bpf_map_lookup_elem(&tls_enhanced_tags, &normalized_tup); } return tags; diff --git a/pkg/network/ebpf/c/protocols/tls/tls-helpers.h b/pkg/network/ebpf/c/protocols/tls/tls-helpers.h new file mode 100644 index 00000000000000..52d40a2a4afaea --- /dev/null +++ b/pkg/network/ebpf/c/protocols/tls/tls-helpers.h @@ -0,0 +1,34 @@ +#ifndef __TLS_HELPERS_H +#define __TLS_HELPERS_H + +#include "tracer/tracer.h" + +// Assume that a zero value for chosen_version and cipher_suite indicates "not set" +#define TLS_VERSION_UNSET 0 +#define CIPHER_SUITE_UNSET 0 + +// merge_tls_info modifies `this` by merging it with `that` +static __always_inline void merge_tls_info(tls_info_t *this, tls_info_t *that) { + if (!this || !that) { + return; + } + + // Merge chosen_version if not already set + if (this->chosen_version == TLS_VERSION_UNSET && that->chosen_version != TLS_VERSION_UNSET) { + this->chosen_version = that->chosen_version; + } + + // Merge cipher_suite if not already set + if (this->cipher_suite == CIPHER_SUITE_UNSET && that->cipher_suite != CIPHER_SUITE_UNSET) { + this->cipher_suite = that->cipher_suite; + } + + // Merge offered_versions bitmask using bitwise OR + this->offered_versions |= that->offered_versions; + + // Merge reserved field if necessary (depending on your use case) + // For now, we can choose to keep it as is or apply specific logic + // this->reserved |= that->reserved; // Uncomment if needed +} + +#endif \ No newline at end of file diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index c68275cc3951a3..098c957975775d 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -5,8 +5,6 @@ #include "bpf_builtins.h" #include "tracer/tracer.h" -#define ETH_HLEN 14 // Ethernet header length - #define SSL_VERSION20 0x0200 #define SSL_VERSION30 0x0300 #define TLS_VERSION10 0x0301 @@ -33,7 +31,7 @@ typedef struct { #define TLS_HANDSHAKE_CLIENT_HELLO 0x01 #define TLS_HANDSHAKE_SERVER_HELLO 0x02 -// Function to check if the given version is a valid TLS version +// is_valid_tls_version checks if the version is a valid TLS version static __always_inline bool is_valid_tls_version(__u16 version) { switch (version) { case SSL_VERSION20: @@ -48,6 +46,7 @@ static __always_inline bool is_valid_tls_version(__u16 version) { } } +// set_tls_offered_version sets the bit corresponding to the offered version in the offered_versions field of tls_info static __always_inline void set_tls_offered_version(tls_info_t *tls_info, __u16 version) { switch (version) { case SSL_VERSION20: @@ -73,7 +72,7 @@ static __always_inline void set_tls_offered_version(tls_info_t *tls_info, __u16 } } -// Helper function to read and validate the TLS record header +// read_tls_record_header reads the TLS record header from the packet static __always_inline bool read_tls_record_header(struct __sk_buff *skb, __u64 nh_off, tls_record_header_t *tls_hdr) { __u32 skb_len = skb->len; @@ -102,7 +101,7 @@ static __always_inline bool read_tls_record_header(struct __sk_buff *skb, __u64 return true; } -// Function to check if the packet is TLS +// is_tls checks if the packet is a TLS packet and reads the TLS record header static __always_inline bool is_tls(struct __sk_buff *skb, __u64 nh_off, tls_record_header_t *tls_hdr) { // Use the helper function to read and validate the TLS record header if (!read_tls_record_header(skb, nh_off, tls_hdr)) @@ -251,7 +250,7 @@ static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offse return 0; } -// Function to parse ServerHello message +// parse_server_hello reads the ServerHello message from the TLS handshake and populates select tags static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offset, __u32 skb_len, tls_info_t *tags) { // Move offset past handshake type (1 byte) offset += 1; @@ -375,7 +374,7 @@ static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offse return 0; } -// Function to parse the TLS payload and update tls_enhanced_tags_t +// parse_tls_payload parses the TLS payload and populates select tags static __always_inline int parse_tls_payload(struct __sk_buff *skb, __u64 nh_off, tls_record_header_t *tls_hdr, tls_info_t *tags) { // At this point, tls_hdr has already been validated and filled by is_tls() __u64 offset = nh_off + sizeof(tls_record_header_t); diff --git a/pkg/network/ebpf/c/tracer/stats.h b/pkg/network/ebpf/c/tracer/stats.h index 574073ac1da63e..81354f939af515 100644 --- a/pkg/network/ebpf/c/tracer/stats.h +++ b/pkg/network/ebpf/c/tracer/stats.h @@ -112,13 +112,7 @@ static __always_inline void update_protocol_classification_information(conn_tupl // lookup from new map and add info to the connection tls_info_t *tls_tags = get_tls_enhanced_tags(&conn_tuple_copy); if (tls_tags) { - // TODO: flush the tags to userspace - log_debug("adamk tls classification: client version 1=%d", tls_tags->offered_versions); - log_debug("adamk tls classification: server version=%d", tls_tags->chosen_version); - log_debug("adamk tls classification: server cipher=%d", tls_tags->cipher_suite); stats->tls_tags = *tls_tags; - } else { - log_debug("adamk tls classification: no tags found"); } conn_tuple_t *cached_skb_conn_tup_ptr = bpf_map_lookup_elem(&conn_tuple_to_socket_skb_conn_tuple, &conn_tuple_copy); @@ -133,14 +127,9 @@ static __always_inline void update_protocol_classification_information(conn_tupl merge_protocol_stacks(&stats->protocol_stack, protocol_stack); // lookup from new map and add info to the connection tls_tags = get_tls_enhanced_tags(&conn_tuple_copy); + // TODO: we should merge the tags if (tls_tags) { - // TODO: flush the tags to userspace - log_debug("adamk tls classification: client version 1=%d", tls_tags->offered_versions); - log_debug("adamk tls classification: server version=%d", tls_tags->chosen_version); - log_debug("adamk tls classification: server cipher=%d", tls_tags->cipher_suite); stats->tls_tags = *tls_tags; - } else { - log_debug("adamk tls classification 2: no tags found"); } } From 9f834c57bb81bc446510015c9b147d4405b4d8ea Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Wed, 13 Nov 2024 10:39:39 -0500 Subject: [PATCH 13/53] move helper --- .../ebpf/c/protocols/tls/tls-helpers.h | 34 ------------------- pkg/network/ebpf/c/tracer/stats.h | 33 +++++++++++++----- 2 files changed, 24 insertions(+), 43 deletions(-) delete mode 100644 pkg/network/ebpf/c/protocols/tls/tls-helpers.h diff --git a/pkg/network/ebpf/c/protocols/tls/tls-helpers.h b/pkg/network/ebpf/c/protocols/tls/tls-helpers.h deleted file mode 100644 index 52d40a2a4afaea..00000000000000 --- a/pkg/network/ebpf/c/protocols/tls/tls-helpers.h +++ /dev/null @@ -1,34 +0,0 @@ -#ifndef __TLS_HELPERS_H -#define __TLS_HELPERS_H - -#include "tracer/tracer.h" - -// Assume that a zero value for chosen_version and cipher_suite indicates "not set" -#define TLS_VERSION_UNSET 0 -#define CIPHER_SUITE_UNSET 0 - -// merge_tls_info modifies `this` by merging it with `that` -static __always_inline void merge_tls_info(tls_info_t *this, tls_info_t *that) { - if (!this || !that) { - return; - } - - // Merge chosen_version if not already set - if (this->chosen_version == TLS_VERSION_UNSET && that->chosen_version != TLS_VERSION_UNSET) { - this->chosen_version = that->chosen_version; - } - - // Merge cipher_suite if not already set - if (this->cipher_suite == CIPHER_SUITE_UNSET && that->cipher_suite != CIPHER_SUITE_UNSET) { - this->cipher_suite = that->cipher_suite; - } - - // Merge offered_versions bitmask using bitwise OR - this->offered_versions |= that->offered_versions; - - // Merge reserved field if necessary (depending on your use case) - // For now, we can choose to keep it as is or apply specific logic - // this->reserved |= that->reserved; // Uncomment if needed -} - -#endif \ No newline at end of file diff --git a/pkg/network/ebpf/c/tracer/stats.h b/pkg/network/ebpf/c/tracer/stats.h index 81354f939af515..8166d07ac3180f 100644 --- a/pkg/network/ebpf/c/tracer/stats.h +++ b/pkg/network/ebpf/c/tracer/stats.h @@ -22,6 +22,26 @@ static __always_inline __u64 offset_rtt(); static __always_inline __u64 offset_rtt_var(); #endif +// merge_tls_info modifies `this` by merging it with `that` +static __always_inline void merge_tls_info(tls_info_t *this, tls_info_t *that) { + if (!this || !that) { + return; + } + + // Merge chosen_version if not already set + if (this->chosen_version == 0 && that->chosen_version != 0) { + this->chosen_version = that->chosen_version; + } + + // Merge cipher_suite if not already set + if (this->cipher_suite == 0 && that->cipher_suite != 0) { + this->cipher_suite = that->cipher_suite; + } + + // Merge offered_versions bitmask + this->offered_versions |= that->offered_versions; +} + static __always_inline conn_stats_ts_t *get_conn_stats(conn_tuple_t *t, struct sock *sk) { conn_stats_ts_t *cs = bpf_map_lookup_elem(&conn_stats, t); if (cs) { @@ -109,11 +129,9 @@ static __always_inline void update_protocol_classification_information(conn_tupl set_protocol_flag(protocol_stack, FLAG_NPM_ENABLED); mark_protocol_direction(t, &conn_tuple_copy, protocol_stack); merge_protocol_stacks(&stats->protocol_stack, protocol_stack); - // lookup from new map and add info to the connection + tls_info_t *tls_tags = get_tls_enhanced_tags(&conn_tuple_copy); - if (tls_tags) { - stats->tls_tags = *tls_tags; - } + merge_tls_info(&stats->tls_tags, tls_tags); conn_tuple_t *cached_skb_conn_tup_ptr = bpf_map_lookup_elem(&conn_tuple_to_socket_skb_conn_tuple, &conn_tuple_copy); if (!cached_skb_conn_tup_ptr) { @@ -125,12 +143,9 @@ static __always_inline void update_protocol_classification_information(conn_tupl set_protocol_flag(protocol_stack, FLAG_NPM_ENABLED); mark_protocol_direction(t, &conn_tuple_copy, protocol_stack); merge_protocol_stacks(&stats->protocol_stack, protocol_stack); - // lookup from new map and add info to the connection + tls_tags = get_tls_enhanced_tags(&conn_tuple_copy); - // TODO: we should merge the tags - if (tls_tags) { - stats->tls_tags = *tls_tags; - } + merge_tls_info(&stats->tls_tags, tls_tags); } static __always_inline void determine_connection_direction(conn_tuple_t *t, conn_stats_ts_t *conn_stats) { From c26c2293b744812323e8419cc94e9ff7f25a9c02 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Wed, 13 Nov 2024 17:54:07 -0500 Subject: [PATCH 14/53] unroll loops and fix windows build --- pkg/network/ebpf/c/protocols/tls/tls.h | 164 ++++++++++++------------ pkg/network/tracer/tracer_linux_test.go | 95 ++++++++++++++ pkg/network/tracer/tracer_test.go | 95 -------------- 3 files changed, 177 insertions(+), 177 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index 098c957975775d..8a9ece57d0634c 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -16,6 +16,7 @@ #define TLS_APPLICATION_DATA 0x17 #define SUPPORTED_VERSIONS_EXTENSION 0x002B +#define MAX_EXTENSIONS 16 /* https://www.rfc-editor.org/rfc/rfc5246#page-19 6.2. Record Layer */ @@ -190,62 +191,62 @@ static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offse // Parse Extensions __u64 extensions_end = offset + extensions_length; - #define MAX_EXTENSIONS 16 __u8 extensions_parsed = 0; - while (offset + 4 <= extensions_end && extensions_parsed < MAX_EXTENSIONS) { - // Read Extension Type (2 bytes) - __u16 extension_type; - if (bpf_skb_load_bytes(skb, offset, &extension_type, sizeof(extension_type)) < 0) - return -1; - extension_type = bpf_ntohs(extension_type); - offset += 2; - - // Read Extension Length (2 bytes) - __u16 extension_length; - if (bpf_skb_load_bytes(skb, offset, &extension_length, sizeof(extension_length)) < 0) - return -1; - extension_length = bpf_ntohs(extension_length); - offset += 2; - - // Ensure we don't read beyond the packet - if (offset + extension_length > skb_len || offset + extension_length > extensions_end) - return -1; - - // Check for supported_versions extension (type 43 or 0x002B) - if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { - // Parse supported_versions extension - if (offset + 1 > skb_len) + #pragma unroll(MAX_EXTENSIONS) + while (offset + 4 <= extensions_end && extensions_parsed < MAX_EXTENSIONS) { + // Read Extension Type (2 bytes) + __u16 extension_type; + if (bpf_skb_load_bytes(skb, offset, &extension_type, sizeof(extension_type)) < 0) return -1; + extension_type = bpf_ntohs(extension_type); + offset += 2; - // Read list length (1 byte) - __u8 sv_list_length; - if (bpf_skb_load_bytes(skb, offset, &sv_list_length, sizeof(sv_list_length)) < 0) + // Read Extension Length (2 bytes) + __u16 extension_length; + if (bpf_skb_load_bytes(skb, offset, &extension_length, sizeof(extension_length)) < 0) return -1; - offset += 1; + extension_length = bpf_ntohs(extension_length); + offset += 2; // Ensure we don't read beyond the packet - if (offset + sv_list_length > skb_len || offset + sv_list_length > extensions_end) + if (offset + extension_length > skb_len || offset + extension_length > extensions_end) return -1; - // Parse versions - __u8 num_versions = 0; - #define MAX_SUPPORTED_VERSIONS 6 - for (__u8 i = 0; i + 1 < sv_list_length && num_versions < MAX_SUPPORTED_VERSIONS; i += 2, num_versions++) { - __u16 sv_version; - if (bpf_skb_load_bytes(skb, offset, &sv_version, sizeof(sv_version)) < 0) + // Check for supported_versions extension (type 43 or 0x002B) + if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { + // Parse supported_versions extension + if (offset + 1 > skb_len) return -1; - sv_version = bpf_ntohs(sv_version); - offset += 2; - set_tls_offered_version(tags, sv_version); + + // Read list length (1 byte) + __u8 sv_list_length; + if (bpf_skb_load_bytes(skb, offset, &sv_list_length, sizeof(sv_list_length)) < 0) + return -1; + offset += 1; + + // Ensure we don't read beyond the packet + if (offset + sv_list_length > skb_len || offset + sv_list_length > extensions_end) + return -1; + + // Parse versions + __u8 num_versions = 0; + #define MAX_SUPPORTED_VERSIONS 6 + for (__u8 i = 0; i + 1 < sv_list_length && num_versions < MAX_SUPPORTED_VERSIONS; i += 2, num_versions++) { + __u16 sv_version; + if (bpf_skb_load_bytes(skb, offset, &sv_version, sizeof(sv_version)) < 0) + return -1; + sv_version = bpf_ntohs(sv_version); + offset += 2; + set_tls_offered_version(tags, sv_version); + } + } else { + // Skip other extensions + offset += extension_length; } - } else { - // Skip other extensions - offset += extension_length; - } - extensions_parsed++; - } + extensions_parsed++; + } return 0; } @@ -323,51 +324,50 @@ static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offse // Parse Extensions __u64 extensions_end = offset + extensions_length; - #define MAX_EXTENSIONS 16 __u8 extensions_parsed = 0; - - while (offset + 4 <= extensions_end && extensions_parsed < MAX_EXTENSIONS) { - // Read Extension Type (2 bytes) - __u16 extension_type; - if (bpf_skb_load_bytes(skb, offset, &extension_type, sizeof(extension_type)) < 0) - return -1; - extension_type = bpf_ntohs(extension_type); - offset += 2; - - // Read Extension Length (2 bytes) - __u16 extension_length; - if (bpf_skb_load_bytes(skb, offset, &extension_length, sizeof(extension_length)) < 0) - return -1; - extension_length = bpf_ntohs(extension_length); - offset += 2; - - // Ensure we don't read beyond the packet - if (offset + extension_length > skb_len || offset + extension_length > extensions_end) - return -1; - - // Check for supported_versions extension (type 43 or 0x002B) - if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { - // Parse supported_versions extension - if (extension_length != 2) - return -1; - - if (offset + 2 > skb_len) + #pragma unroll(MAX_EXTENSIONS) + while (offset + 4 <= extensions_end && extensions_parsed < MAX_EXTENSIONS) { + // Read Extension Type (2 bytes) + __u16 extension_type; + if (bpf_skb_load_bytes(skb, offset, &extension_type, sizeof(extension_type)) < 0) return -1; + extension_type = bpf_ntohs(extension_type); + offset += 2; - // Read selected version (2 bytes) - __u16 selected_version; - if (bpf_skb_load_bytes(skb, offset, &selected_version, sizeof(selected_version)) < 0) + // Read Extension Length (2 bytes) + __u16 extension_length; + if (bpf_skb_load_bytes(skb, offset, &extension_length, sizeof(extension_length)) < 0) return -1; - selected_version = bpf_ntohs(selected_version); + extension_length = bpf_ntohs(extension_length); offset += 2; - tags->chosen_version = selected_version; - } else { - // Skip other extensions - offset += extension_length; - } + // Ensure we don't read beyond the packet + if (offset + extension_length > skb_len || offset + extension_length > extensions_end) + return -1; - extensions_parsed++; + // Check for supported_versions extension (type 43 or 0x002B) + if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { + // Parse supported_versions extension + if (extension_length != 2) + return -1; + + if (offset + 2 > skb_len) + return -1; + + // Read selected version (2 bytes) + __u16 selected_version; + if (bpf_skb_load_bytes(skb, offset, &selected_version, sizeof(selected_version)) < 0) + return -1; + selected_version = bpf_ntohs(selected_version); + offset += 2; + + tags->chosen_version = selected_version; + } else { + // Skip other extensions + offset += extension_length; + } + + extensions_parsed++; } } diff --git a/pkg/network/tracer/tracer_linux_test.go b/pkg/network/tracer/tracer_linux_test.go index eddeeeec174f2d..dc41cc01b2db11 100644 --- a/pkg/network/tracer/tracer_linux_test.go +++ b/pkg/network/tracer/tracer_linux_test.go @@ -11,6 +11,7 @@ import ( "bufio" "bytes" "context" + "crypto/tls" "errors" "fmt" "io" @@ -50,8 +51,12 @@ import ( "github.com/DataDog/datadog-agent/pkg/network/config/sysctl" "github.com/DataDog/datadog-agent/pkg/network/events" netlinktestutil "github.com/DataDog/datadog-agent/pkg/network/netlink/testutil" + "github.com/DataDog/datadog-agent/pkg/network/protocols" + usmtestutil "github.com/DataDog/datadog-agent/pkg/network/protocols/http/testutil" + ddtls "github.com/DataDog/datadog-agent/pkg/network/protocols/tls" "github.com/DataDog/datadog-agent/pkg/network/testutil" "github.com/DataDog/datadog-agent/pkg/network/tracer/connection" + "github.com/DataDog/datadog-agent/pkg/network/tracer/connection/kprobe" "github.com/DataDog/datadog-agent/pkg/network/tracer/offsetguess" tracertestutil "github.com/DataDog/datadog-agent/pkg/network/tracer/testutil" "github.com/DataDog/datadog-agent/pkg/network/tracer/testutil/testdns" @@ -2547,6 +2552,96 @@ func setupDropTrafficRule(tb testing.TB) (ns string) { return } +func (s *TracerSuite) TestTLSClassification() { + t := s.T() + cfg := testConfig() + + if !kprobe.ClassificationSupported(cfg) { + t.Skip("TLS classification platform not supported") + } + port, err := tracertestutil.GetFreePort() + require.NoError(t, err) + portAsString := strconv.Itoa(int(port)) + + tr := setupTracer(t, cfg) + + type tlsTest struct { + name string + postTracerSetup func(t *testing.T) + validation func(t *testing.T, tr *Tracer) + } + tests := make([]tlsTest, 0) + for _, scenario := range []uint16{tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13} { + scenario := scenario + tests = append(tests, tlsTest{ + name: strings.Replace(tls.VersionName(scenario), " ", "-", 1), + postTracerSetup: func(t *testing.T) { + srv := usmtestutil.NewTLSServerWithSpecificVersion("localhost:"+portAsString, func(conn net.Conn) { + defer conn.Close() + // Echo back whatever is received + _, err := io.Copy(conn, conn) + if err != nil { + fmt.Printf("Failed to echo data: %v\n", err) + return + } + }, scenario) + done := make(chan struct{}) + require.NoError(t, srv.Run(done)) + t.Cleanup(func() { close(done) }) + tlsConfig := &tls.Config{ + MinVersion: scenario, + MaxVersion: scenario, + InsecureSkipVerify: true, + SessionTicketsDisabled: true, // Disable session tickets + ClientSessionCache: nil, // Disable session cache + } + conn, err := net.Dial("tcp", "localhost:"+portAsString) + require.NoError(t, err) + defer conn.Close() + + // Wrap the TCP connection with TLS + tlsConn := tls.Client(conn, tlsConfig) + + // Perform the TLS handshake + require.NoError(t, tlsConn.Handshake()) + }, + validation: func(t *testing.T, tr *Tracer) { + require.Eventuallyf(t, func() bool { + payload := getConnections(t, tr) + for _, c := range payload.Conns { + if c.DPort == port && c.ProtocolStack.Contains(protocols.TLS) && !c.TLSTags.IsEmpty() { + expectedTagKey := ddtls.TagTLSVersion + tls.VersionName(scenario) + tlsTags := ddtls.GetTLSDynamicTags(&c.TLSTags) + if _, ok := tlsTags[expectedTagKey]; !ok { + return false + } + return true + } + } + return false + }, 3*time.Second, 100*time.Millisecond, "couldn't find TLS connection matching: dst port %v", portAsString) + }, + }) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if ebpftest.GetBuildMode() == ebpftest.Fentry { + t.Skip("protocol classification not supported for fentry tracer") + } + t.Cleanup(func() { tr.RemoveClient(clientID) }) + t.Cleanup(func() { _ = tr.Pause() }) + + tr.RemoveClient(clientID) + require.NoError(t, tr.RegisterClient(clientID)) + require.NoError(t, tr.Resume(), "enable probes - before post tracer") + tt.postTracerSetup(t) + require.NoError(t, tr.Pause(), "disable probes - after post tracer") + tt.validation(t, tr) + }) + } +} + func skipOnEbpflessNotSupported(t *testing.T, cfg *config.Config) { if cfg.EnableEbpfless { t.Skip("not supported on ebpf-less") diff --git a/pkg/network/tracer/tracer_test.go b/pkg/network/tracer/tracer_test.go index 09db84d7d3e8bd..c338661f7c02b2 100644 --- a/pkg/network/tracer/tracer_test.go +++ b/pkg/network/tracer/tracer_test.go @@ -11,7 +11,6 @@ import ( "bufio" "bytes" "context" - "crypto/tls" "encoding/json" "errors" "fmt" @@ -38,10 +37,6 @@ import ( "github.com/DataDog/datadog-agent/pkg/ebpf/ebpftest" "github.com/DataDog/datadog-agent/pkg/network" "github.com/DataDog/datadog-agent/pkg/network/config" - "github.com/DataDog/datadog-agent/pkg/network/protocols" - usmtestutil "github.com/DataDog/datadog-agent/pkg/network/protocols/http/testutil" - ddtls "github.com/DataDog/datadog-agent/pkg/network/protocols/tls" - "github.com/DataDog/datadog-agent/pkg/network/tracer/connection/kprobe" "github.com/DataDog/datadog-agent/pkg/network/tracer/testutil" "github.com/DataDog/datadog-agent/pkg/network/tracer/testutil/testdns" "github.com/DataDog/datadog-agent/pkg/process/util" @@ -1456,93 +1451,3 @@ func BenchmarkGetActiveConnections(b *testing.B) { assert.Equal(b, uint32(1), conn.Last.TCPClosed) } } - -func (s *TracerSuite) TestTLSClassification() { - t := s.T() - cfg := testConfig() - - if !kprobe.ClassificationSupported(cfg) { - t.Skip("TLS classification platform not supported") - } - port, err := testutil.GetFreePort() - require.NoError(t, err) - portAsString := strconv.Itoa(int(port)) - - tr := setupTracer(t, cfg) - - type tlsTest struct { - name string - postTracerSetup func(t *testing.T) - validation func(t *testing.T, tr *Tracer) - } - tests := make([]tlsTest, 0) - for _, scenario := range []uint16{tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13} { - scenario := scenario - tests = append(tests, tlsTest{ - name: strings.Replace(tls.VersionName(scenario), " ", "-", 1), - postTracerSetup: func(t *testing.T) { - srv := usmtestutil.NewTLSServerWithSpecificVersion("localhost:"+portAsString, func(conn net.Conn) { - defer conn.Close() - // Echo back whatever is received - _, err := io.Copy(conn, conn) - if err != nil { - fmt.Printf("Failed to echo data: %v\n", err) - return - } - }, scenario) - done := make(chan struct{}) - require.NoError(t, srv.Run(done)) - t.Cleanup(func() { close(done) }) - tlsConfig := &tls.Config{ - MinVersion: scenario, - MaxVersion: scenario, - InsecureSkipVerify: true, - SessionTicketsDisabled: true, // Disable session tickets - ClientSessionCache: nil, // Disable session cache - } - conn, err := net.Dial("tcp", "localhost:"+portAsString) - require.NoError(t, err) - defer conn.Close() - - // Wrap the TCP connection with TLS - tlsConn := tls.Client(conn, tlsConfig) - - // Perform the TLS handshake - require.NoError(t, tlsConn.Handshake()) - }, - validation: func(t *testing.T, tr *Tracer) { - require.Eventuallyf(t, func() bool { - payload := getConnections(t, tr) - for _, c := range payload.Conns { - if c.DPort == port && c.ProtocolStack.Contains(protocols.TLS) && !c.TLSTags.IsEmpty() { - expectedTagKey := ddtls.TagTLSVersion + tls.VersionName(scenario) - tlsTags := ddtls.GetTLSDynamicTags(&c.TLSTags) - if _, ok := tlsTags[expectedTagKey]; !ok { - return false - } - return true - } - } - return false - }, 3*time.Second, 100*time.Millisecond, "couldn't find TLS connection matching: dst port %v", portAsString) - }, - }) - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if ebpftest.GetBuildMode() == ebpftest.Fentry { - t.Skip("protocol classification not supported for fentry tracer") - } - t.Cleanup(func() { tr.RemoveClient(clientID) }) - t.Cleanup(func() { _ = tr.Pause() }) - - tr.RemoveClient(clientID) - require.NoError(t, tr.RegisterClient(clientID)) - require.NoError(t, tr.Resume(), "enable probes - before post tracer") - tt.postTracerSetup(t) - require.NoError(t, tr.Pause(), "disable probes - after post tracer") - tt.validation(t, tr) - }) - } -} From 9576add2f40e6377ef778db783e3640f3aa7d997 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Thu, 14 Nov 2024 11:14:13 -0500 Subject: [PATCH 15/53] convert to for loops --- pkg/network/ebpf/c/protocols/tls/tls.h | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index 8a9ece57d0634c..b7db217f460915 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -194,7 +194,10 @@ static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offse __u8 extensions_parsed = 0; #pragma unroll(MAX_EXTENSIONS) - while (offset + 4 <= extensions_end && extensions_parsed < MAX_EXTENSIONS) { + for (int i = 0; i < MAX_EXTENSIONS; i++) { + if (offset + 4 > extensions_end) { + break; + } // Read Extension Type (2 bytes) __u16 extension_type; if (bpf_skb_load_bytes(skb, offset, &extension_type, sizeof(extension_type)) < 0) @@ -326,7 +329,10 @@ static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offse __u64 extensions_end = offset + extensions_length; __u8 extensions_parsed = 0; #pragma unroll(MAX_EXTENSIONS) - while (offset + 4 <= extensions_end && extensions_parsed < MAX_EXTENSIONS) { + for (int i = 0; i < MAX_EXTENSIONS; i++) { + if (offset + 4 > extensions_end) { + break; + } // Read Extension Type (2 bytes) __u16 extension_type; if (bpf_skb_load_bytes(skb, offset, &extension_type, sizeof(extension_type)) < 0) From 344993de34329509de761c2c2a8500d8407f347b Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Thu, 14 Nov 2024 13:19:17 -0500 Subject: [PATCH 16/53] reduce max extensions parsed --- pkg/network/ebpf/c/protocols/tls/tls.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index b7db217f460915..2834ac6cc59336 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -16,7 +16,7 @@ #define TLS_APPLICATION_DATA 0x17 #define SUPPORTED_VERSIONS_EXTENSION 0x002B -#define MAX_EXTENSIONS 16 +#define MAX_EXTENSIONS 8 /* https://www.rfc-editor.org/rfc/rfc5246#page-19 6.2. Record Layer */ From d3b426cab74ab9a8831cfb12edb4eca01e8553b1 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Thu, 14 Nov 2024 20:11:07 -0500 Subject: [PATCH 17/53] further attempts to limit stack usage --- .../classification/protocol-classification.h | 2 +- pkg/network/ebpf/c/protocols/tls/tls.h | 18 ++++++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h index 9d00aa965d5dab..f08e90832e83cf 100644 --- a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h +++ b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h @@ -171,7 +171,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct tls_record_header_t tls_hdr = {0}; - if (is_tls(skb, skb_info.data_off, &tls_hdr)) { + if ((app_layer_proto == PROTOCOL_UNKNOWN || app_layer_proto == PROTOCOL_POSTGRES) && is_tls(skb, skb_info.data_off, &tls_hdr)) { // TLS classification update_protocol_information(usm_ctx, protocol_stack, PROTOCOL_TLS); diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index 2834ac6cc59336..f486ecc005ba9e 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -16,7 +16,7 @@ #define TLS_APPLICATION_DATA 0x17 #define SUPPORTED_VERSIONS_EXTENSION 0x002B -#define MAX_EXTENSIONS 8 +#define MAX_EXTENSIONS 6 /* https://www.rfc-editor.org/rfc/rfc5246#page-19 6.2. Record Layer */ @@ -143,6 +143,14 @@ static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offse // Store client_version in tags (in case supported_versions extension is absent) set_tls_offered_version(tags, client_version); + if (client_version != TLS_VERSION12) { + // if the version is less than 1.2, there won't be any extensions and we can stop here + return 0; + } + + // Check if there are extensions if the version is listed as TLS 1.2, as this + // version may actually be 1.3 and the real version is in the extensions + // Skip Random (32 bytes) offset += 32; @@ -310,7 +318,13 @@ static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offse // Store parsed data into tags tags->cipher_suite = cipher_suite; - // Check if there are extensions + if (tags->chosen_version != TLS_VERSION12) { + // if the version is less than 1.2, there won't be any extensions and we can stop here + return 0; + } + + // Check if there are extensions if the version is listed as TLS 1.2, as this + // version may actually be 1.3 and the real version is in the extensions if (offset < handshake_end) { // Read Extensions Length (2 bytes) if (offset + 2 > skb_len || offset + 2 > handshake_end) From 3b197d9dbd0f9a060c2085408318efbc1dddec38 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Mon, 18 Nov 2024 12:16:24 -0500 Subject: [PATCH 18/53] comment --- pkg/network/ebpf/probes/probes.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/network/ebpf/probes/probes.go b/pkg/network/ebpf/probes/probes.go index a791854fd3fb5d..4075e6f7616181 100644 --- a/pkg/network/ebpf/probes/probes.go +++ b/pkg/network/ebpf/probes/probes.go @@ -234,7 +234,8 @@ const ( ConnectionProtocolMap BPFMapName = "connection_protocol" // ConnectionTupleToSocketSKBConnMap is the map storing the connection tuple to socket skb conn tuple ConnectionTupleToSocketSKBConnMap BPFMapName = "conn_tuple_to_socket_skb_conn_tuple" - EnhancedTLSTagsMap BPFMapName = "tls_enhanced_tags" + // EnhancedTLSTagsMap is the map storing additional tags for TLS connections (version, cipher, etc.) + EnhancedTLSTagsMap BPFMapName = "tls_enhanced_tags" // ClassificationProgsMap is the map storing the programs to run on classification events ClassificationProgsMap BPFMapName = "classification_progs" // TCPCloseProgsMap is the map storing the programs to run on TCP close events From 0889697c12de7981a253f077ab0f1d14ac83a5eb Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Mon, 18 Nov 2024 23:28:02 -0500 Subject: [PATCH 19/53] test with 2 loops to validate old kernels --- pkg/network/ebpf/c/protocols/tls/tls.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index f486ecc005ba9e..fe4709861e3d82 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -16,7 +16,7 @@ #define TLS_APPLICATION_DATA 0x17 #define SUPPORTED_VERSIONS_EXTENSION 0x002B -#define MAX_EXTENSIONS 6 +#define MAX_EXTENSIONS 2 /* https://www.rfc-editor.org/rfc/rfc5246#page-19 6.2. Record Layer */ From 708d6b1272020ac4e9d757c189afc35eb7e5cc9c Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Tue, 19 Nov 2024 10:14:53 -0500 Subject: [PATCH 20/53] try defining loop variables outside unroll --- pkg/network/ebpf/c/protocols/tls/tls.h | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index fe4709861e3d82..2cdc59b0b0a584 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -200,6 +200,11 @@ static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offse // Parse Extensions __u64 extensions_end = offset + extensions_length; __u8 extensions_parsed = 0; + __u16 extension_type; + __u16 extension_length; + __u8 sv_list_length; + __u8 num_versions = 0; + __u16 sv_version; #pragma unroll(MAX_EXTENSIONS) for (int i = 0; i < MAX_EXTENSIONS; i++) { @@ -207,14 +212,12 @@ static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offse break; } // Read Extension Type (2 bytes) - __u16 extension_type; if (bpf_skb_load_bytes(skb, offset, &extension_type, sizeof(extension_type)) < 0) return -1; extension_type = bpf_ntohs(extension_type); offset += 2; // Read Extension Length (2 bytes) - __u16 extension_length; if (bpf_skb_load_bytes(skb, offset, &extension_length, sizeof(extension_length)) < 0) return -1; extension_length = bpf_ntohs(extension_length); @@ -231,7 +234,6 @@ static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offse return -1; // Read list length (1 byte) - __u8 sv_list_length; if (bpf_skb_load_bytes(skb, offset, &sv_list_length, sizeof(sv_list_length)) < 0) return -1; offset += 1; @@ -241,10 +243,8 @@ static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offse return -1; // Parse versions - __u8 num_versions = 0; #define MAX_SUPPORTED_VERSIONS 6 for (__u8 i = 0; i + 1 < sv_list_length && num_versions < MAX_SUPPORTED_VERSIONS; i += 2, num_versions++) { - __u16 sv_version; if (bpf_skb_load_bytes(skb, offset, &sv_version, sizeof(sv_version)) < 0) return -1; sv_version = bpf_ntohs(sv_version); @@ -342,20 +342,21 @@ static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offse // Parse Extensions __u64 extensions_end = offset + extensions_length; __u8 extensions_parsed = 0; + __u16 extension_type; + __u16 extension_length; + __u16 selected_version; #pragma unroll(MAX_EXTENSIONS) for (int i = 0; i < MAX_EXTENSIONS; i++) { if (offset + 4 > extensions_end) { break; } // Read Extension Type (2 bytes) - __u16 extension_type; if (bpf_skb_load_bytes(skb, offset, &extension_type, sizeof(extension_type)) < 0) return -1; extension_type = bpf_ntohs(extension_type); offset += 2; // Read Extension Length (2 bytes) - __u16 extension_length; if (bpf_skb_load_bytes(skb, offset, &extension_length, sizeof(extension_length)) < 0) return -1; extension_length = bpf_ntohs(extension_length); @@ -375,7 +376,6 @@ static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offse return -1; // Read selected version (2 bytes) - __u16 selected_version; if (bpf_skb_load_bytes(skb, offset, &selected_version, sizeof(selected_version)) < 0) return -1; selected_version = bpf_ntohs(selected_version); From 69e0d3d1e7a168bda414d78f64dc98ff4de8e085 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Tue, 19 Nov 2024 16:02:25 -0500 Subject: [PATCH 21/53] unroll inner loop --- pkg/network/ebpf/c/protocols/tls/tls.h | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index 2cdc59b0b0a584..d5677483a428aa 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -203,8 +203,6 @@ static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offse __u16 extension_type; __u16 extension_length; __u8 sv_list_length; - __u8 num_versions = 0; - __u16 sv_version; #pragma unroll(MAX_EXTENSIONS) for (int i = 0; i < MAX_EXTENSIONS; i++) { @@ -242,14 +240,29 @@ static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offse if (offset + sv_list_length > skb_len || offset + sv_list_length > extensions_end) return -1; - // Parse versions #define MAX_SUPPORTED_VERSIONS 6 - for (__u8 i = 0; i + 1 < sv_list_length && num_versions < MAX_SUPPORTED_VERSIONS; i += 2, num_versions++) { + __u8 num_versions = 0; + __u8 i = 0; + __u16 sv_version; + + #pragma unroll(MAX_SUPPORTED_VERSIONS) + for (int idx = 0; idx < MAX_SUPPORTED_VERSIONS; idx++) { + if (i + 1 >= sv_list_length) + break; + if (offset + 2 > skb_len) + return -1; + + // Load the supported version if (bpf_skb_load_bytes(skb, offset, &sv_version, sizeof(sv_version)) < 0) return -1; sv_version = bpf_ntohs(sv_version); offset += 2; + + // Store the version set_tls_offered_version(tags, sv_version); + + num_versions++; + i += 2; } } else { // Skip other extensions From 3e65bac3e24c5a6ef0bc813390daea9c1dde8ede Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Tue, 19 Nov 2024 18:27:54 -0500 Subject: [PATCH 22/53] set max tags to 8 again --- pkg/network/ebpf/c/protocols/tls/tls.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index d5677483a428aa..1593404b761e9e 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -16,7 +16,7 @@ #define TLS_APPLICATION_DATA 0x17 #define SUPPORTED_VERSIONS_EXTENSION 0x002B -#define MAX_EXTENSIONS 2 +#define MAX_EXTENSIONS 8 /* https://www.rfc-editor.org/rfc/rfc5246#page-19 6.2. Record Layer */ From aabd648901731f45f70cfddae5886734d7dc3bfe Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Tue, 19 Nov 2024 23:03:31 -0500 Subject: [PATCH 23/53] 6 extensions parsed --- pkg/network/ebpf/c/protocols/tls/tls.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index 1593404b761e9e..1be92efd91d294 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -16,7 +16,7 @@ #define TLS_APPLICATION_DATA 0x17 #define SUPPORTED_VERSIONS_EXTENSION 0x002B -#define MAX_EXTENSIONS 8 +#define MAX_EXTENSIONS 6 /* https://www.rfc-editor.org/rfc/rfc5246#page-19 6.2. Record Layer */ From fc930a8703b2384bf717069064de1fca57ee443f Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Mon, 25 Nov 2024 22:07:10 -0500 Subject: [PATCH 24/53] stop unrolling loops --- .../classification/protocol-classification.h | 4 - pkg/network/ebpf/c/protocols/tls/tls.h | 178 +++++++++--------- 2 files changed, 89 insertions(+), 93 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h index ef46a2b43d0bbd..2c6be8dc534ef0 100644 --- a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h +++ b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h @@ -179,10 +179,6 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct return; } - if(is_protocol_layer_known(protocol_stack, LAYER_ENCRYPTION)) { - return; - } - if (app_layer_proto != PROTOCOL_UNKNOWN && app_layer_proto != PROTOCOL_HTTP2) { goto next_program; } diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index 475c72573a8c0a..4bb6215c774f19 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -204,74 +204,74 @@ static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offse __u16 extension_length; __u8 sv_list_length; - #pragma unroll(MAX_EXTENSIONS) - for (int i = 0; i < MAX_EXTENSIONS; i++) { - if (offset + 4 > extensions_end) { - break; - } - // Read Extension Type (2 bytes) - if (bpf_skb_load_bytes(skb, offset, &extension_type, sizeof(extension_type)) < 0) + // #pragma unroll(MAX_EXTENSIONS) + for (int i = 0; i < MAX_EXTENSIONS; i++) { + if (offset + 4 > extensions_end) { + break; + } + // Read Extension Type (2 bytes) + if (bpf_skb_load_bytes(skb, offset, &extension_type, sizeof(extension_type)) < 0) + return -1; + extension_type = bpf_ntohs(extension_type); + offset += 2; + + // Read Extension Length (2 bytes) + if (bpf_skb_load_bytes(skb, offset, &extension_length, sizeof(extension_length)) < 0) + return -1; + extension_length = bpf_ntohs(extension_length); + offset += 2; + + // Ensure we don't read beyond the packet + if (offset + extension_length > skb_len || offset + extension_length > extensions_end) + return -1; + + // Check for supported_versions extension (type 43 or 0x002B) + if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { + // Parse supported_versions extension + if (offset + 1 > skb_len) return -1; - extension_type = bpf_ntohs(extension_type); - offset += 2; - // Read Extension Length (2 bytes) - if (bpf_skb_load_bytes(skb, offset, &extension_length, sizeof(extension_length)) < 0) + // Read list length (1 byte) + if (bpf_skb_load_bytes(skb, offset, &sv_list_length, sizeof(sv_list_length)) < 0) return -1; - extension_length = bpf_ntohs(extension_length); - offset += 2; + offset += 1; // Ensure we don't read beyond the packet - if (offset + extension_length > skb_len || offset + extension_length > extensions_end) + if (offset + sv_list_length > skb_len || offset + sv_list_length > extensions_end) return -1; - // Check for supported_versions extension (type 43 or 0x002B) - if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { - // Parse supported_versions extension - if (offset + 1 > skb_len) - return -1; + #define MAX_SUPPORTED_VERSIONS 6 + __u8 num_versions = 0; + __u8 i = 0; + __u16 sv_version; - // Read list length (1 byte) - if (bpf_skb_load_bytes(skb, offset, &sv_list_length, sizeof(sv_list_length)) < 0) + // #pragma unroll(MAX_SUPPORTED_VERSIONS) + for (int idx = 0; idx < MAX_SUPPORTED_VERSIONS; idx++) { + if (i + 1 >= sv_list_length) + break; + if (offset + 2 > skb_len) return -1; - offset += 1; - // Ensure we don't read beyond the packet - if (offset + sv_list_length > skb_len || offset + sv_list_length > extensions_end) + // Load the supported version + if (bpf_skb_load_bytes(skb, offset, &sv_version, sizeof(sv_version)) < 0) return -1; + sv_version = bpf_ntohs(sv_version); + offset += 2; - #define MAX_SUPPORTED_VERSIONS 6 - __u8 num_versions = 0; - __u8 i = 0; - __u16 sv_version; - - #pragma unroll(MAX_SUPPORTED_VERSIONS) - for (int idx = 0; idx < MAX_SUPPORTED_VERSIONS; idx++) { - if (i + 1 >= sv_list_length) - break; - if (offset + 2 > skb_len) - return -1; - - // Load the supported version - if (bpf_skb_load_bytes(skb, offset, &sv_version, sizeof(sv_version)) < 0) - return -1; - sv_version = bpf_ntohs(sv_version); - offset += 2; - - // Store the version - set_tls_offered_version(tags, sv_version); - - num_versions++; - i += 2; - } - } else { - // Skip other extensions - offset += extension_length; - } + // Store the version + set_tls_offered_version(tags, sv_version); - extensions_parsed++; + num_versions++; + i += 2; + } + } else { + // Skip other extensions + offset += extension_length; } + extensions_parsed++; + } + return 0; } @@ -358,49 +358,49 @@ static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offse __u16 extension_type; __u16 extension_length; __u16 selected_version; - #pragma unroll(MAX_EXTENSIONS) - for (int i = 0; i < MAX_EXTENSIONS; i++) { - if (offset + 4 > extensions_end) { - break; - } - // Read Extension Type (2 bytes) - if (bpf_skb_load_bytes(skb, offset, &extension_type, sizeof(extension_type)) < 0) - return -1; - extension_type = bpf_ntohs(extension_type); - offset += 2; + // #pragma unroll(MAX_EXTENSIONS) + for (int i = 0; i < MAX_EXTENSIONS; i++) { + if (offset + 4 > extensions_end) { + break; + } + // Read Extension Type (2 bytes) + if (bpf_skb_load_bytes(skb, offset, &extension_type, sizeof(extension_type)) < 0) + return -1; + extension_type = bpf_ntohs(extension_type); + offset += 2; - // Read Extension Length (2 bytes) - if (bpf_skb_load_bytes(skb, offset, &extension_length, sizeof(extension_length)) < 0) - return -1; - extension_length = bpf_ntohs(extension_length); - offset += 2; + // Read Extension Length (2 bytes) + if (bpf_skb_load_bytes(skb, offset, &extension_length, sizeof(extension_length)) < 0) + return -1; + extension_length = bpf_ntohs(extension_length); + offset += 2; - // Ensure we don't read beyond the packet - if (offset + extension_length > skb_len || offset + extension_length > extensions_end) - return -1; + // Ensure we don't read beyond the packet + if (offset + extension_length > skb_len || offset + extension_length > extensions_end) + return -1; - // Check for supported_versions extension (type 43 or 0x002B) - if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { - // Parse supported_versions extension - if (extension_length != 2) - return -1; + // Check for supported_versions extension (type 43 or 0x002B) + if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { + // Parse supported_versions extension + if (extension_length != 2) + return -1; - if (offset + 2 > skb_len) - return -1; + if (offset + 2 > skb_len) + return -1; - // Read selected version (2 bytes) - if (bpf_skb_load_bytes(skb, offset, &selected_version, sizeof(selected_version)) < 0) - return -1; - selected_version = bpf_ntohs(selected_version); - offset += 2; + // Read selected version (2 bytes) + if (bpf_skb_load_bytes(skb, offset, &selected_version, sizeof(selected_version)) < 0) + return -1; + selected_version = bpf_ntohs(selected_version); + offset += 2; - tags->chosen_version = selected_version; - } else { - // Skip other extensions - offset += extension_length; - } + tags->chosen_version = selected_version; + } else { + // Skip other extensions + offset += extension_length; + } - extensions_parsed++; + extensions_parsed++; } } From c9da0b3822643064c3704799356c942ae8d0803c Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Tue, 26 Nov 2024 10:27:29 -0500 Subject: [PATCH 25/53] raise min version --- pkg/network/tracer/connection/kprobe/tracer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/network/tracer/connection/kprobe/tracer.go b/pkg/network/tracer/connection/kprobe/tracer.go index 634642154e8818..1ed4ed1873b07f 100644 --- a/pkg/network/tracer/connection/kprobe/tracer.go +++ b/pkg/network/tracer/connection/kprobe/tracer.go @@ -42,7 +42,7 @@ var ( // The kernel has to be newer than 4.11.0 since we are using bpf_skb_load_bytes (4.5.0+), which // was added to socket filters in 4.11.0: // - 2492d3b867043f6880708d095a7a5d65debcfc32 - classificationMinimumKernel = kernel.VersionCode(4, 11, 0) + classificationMinimumKernel = kernel.VersionCode(4, 15, 0) protocolClassificationTailCalls = []manager.TailCallRoute{ { From af58525e5c36de81270a596f0a1f0dcfaa26d1f9 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Mon, 2 Dec 2024 14:41:49 -0500 Subject: [PATCH 26/53] fix routing and tagging, increase max extensions --- .../ebpf/c/protocols/classification/defs.h | 30 ++- .../classification/protocol-classification.h | 62 ++++- .../classification/routing-helpers.h | 22 +- .../ebpf/c/protocols/classification/routing.h | 6 +- .../classification/shared-tracer-maps.h | 16 +- .../protocols/classification/stack-helpers.h | 26 +- .../c/protocols/classification/usm-context.h | 1 + pkg/network/ebpf/c/protocols/tls/tls.h | 230 +++++++++--------- pkg/network/ebpf/c/tracer.c | 12 + pkg/network/ebpf/kprobe_types.go | 8 +- pkg/network/ebpf/kprobe_types_linux.go | 10 +- pkg/network/ebpf/probes/probes.go | 4 + .../tracer/connection/kprobe/config.go | 2 + .../tracer/connection/kprobe/manager.go | 2 + .../tracer/connection/kprobe/tracer.go | 22 +- pkg/network/tracer/tracer_linux_test.go | 6 +- 16 files changed, 286 insertions(+), 173 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/classification/defs.h b/pkg/network/ebpf/c/protocols/classification/defs.h index f90c54b36fb82c..2a690bea9767f0 100644 --- a/pkg/network/ebpf/c/protocols/classification/defs.h +++ b/pkg/network/ebpf/c/protocols/classification/defs.h @@ -21,13 +21,13 @@ // The maximum number of protocols per stack layer #define MAX_ENTRIES_PER_LAYER 255 -#define LAYER_API_BIT (1 << 13) +#define LAYER_ENCRYPTION_BIT (1 << 13) #define LAYER_APPLICATION_BIT (1 << 14) -#define LAYER_ENCRYPTION_BIT (1 << 15) +#define LAYER_API_BIT (1 << 15) -#define LAYER_API_MAX (LAYER_API_BIT + MAX_ENTRIES_PER_LAYER) -#define LAYER_APPLICATION_MAX (LAYER_APPLICATION_BIT + MAX_ENTRIES_PER_LAYER) #define LAYER_ENCRYPTION_MAX (LAYER_ENCRYPTION_BIT + MAX_ENTRIES_PER_LAYER) +#define LAYER_APPLICATION_MAX (LAYER_APPLICATION_BIT + MAX_ENTRIES_PER_LAYER) +#define LAYER_API_MAX (LAYER_API_BIT + MAX_ENTRIES_PER_LAYER) #define FLAG_FULLY_CLASSIFIED 1 << 0 #define FLAG_USM_ENABLED 1 << 1 @@ -48,6 +48,11 @@ typedef enum { PROTOCOL_UNKNOWN = 0, + __LAYER_ENCRYPTION_MIN = LAYER_ENCRYPTION_BIT, + // Add encryption protocols below (eg. TLS) + PROTOCOL_TLS, + __LAYER_ENCRYPTION_MAX = LAYER_ENCRYPTION_MAX, + __LAYER_API_MIN = LAYER_API_BIT, // Add API protocols here (eg. gRPC) PROTOCOL_GRPC, @@ -65,10 +70,6 @@ typedef enum { PROTOCOL_MYSQL, __LAYER_APPLICATION_MAX = LAYER_APPLICATION_MAX, - __LAYER_ENCRYPTION_MIN = LAYER_ENCRYPTION_BIT, - // Add encryption protocols below (eg. TLS) - PROTOCOL_TLS, - __LAYER_ENCRYPTION_MAX = LAYER_ENCRYPTION_MAX, } __attribute__ ((packed)) protocol_t; // This enum represents all existing protocol layers @@ -76,19 +77,19 @@ typedef enum { // Each `protocol_t` entry is implicitly associated to a single // `protocol_layer_t` value (see notes above). // -//In order to determine which `protocol_layer_t` a `protocol_t` belongs to, +// In order to determine which `protocol_layer_t` a `protocol_t` belongs to, // users can call `get_protocol_layer` typedef enum { LAYER_UNKNOWN, + LAYER_ENCRYPTION, LAYER_API, LAYER_APPLICATION, - LAYER_ENCRYPTION, } __attribute__ ((packed)) protocol_layer_t; typedef struct { + __u8 layer_encryption; __u8 layer_api; __u8 layer_application; - __u8 layer_encryption; __u8 flags; } protocol_stack_t; @@ -114,6 +115,10 @@ typedef struct { typedef enum { CLASSIFICATION_PROG_UNKNOWN = 0, + __PROG_ENCRYPTION, + // Encryption classification programs go here + CLASSIFICATION_TLS_CLIENT_PROG, + CLASSIFICATION_TLS_SERVER_PROG, __PROG_APPLICATION, // Application classification programs go here CLASSIFICATION_QUEUES_PROG, @@ -121,8 +126,7 @@ typedef enum { __PROG_API, // API classification programs go here CLASSIFICATION_GRPC_PROG, - __PROG_ENCRYPTION, - // Encryption classification programs go here + // Add before this value CLASSIFICATION_PROG_MAX, } classification_prog_t; diff --git a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h index 2c6be8dc534ef0..0596fe3a45778a 100644 --- a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h +++ b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h @@ -167,15 +167,18 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct if ((app_layer_proto == PROTOCOL_UNKNOWN || app_layer_proto == PROTOCOL_POSTGRES) && is_tls(skb, skb_info.data_off, &tls_hdr)) { // TLS classification - update_protocol_information(usm_ctx, protocol_stack, PROTOCOL_TLS); + log_debug("adamk TLS classification started"); + // TODO: check if it's a TLS app data message. If so we're too late, label as TLS and bail from the classification + // update_protocol_information(usm_ctx, protocol_stack, PROTOCOL_TLS); // Parse TLS payload - tls_info_t *tags = get_or_create_tls_enhanced_tags(&skb_tup); + tls_info_t *tags = get_or_create_tls_enhanced_tags(&usm_ctx->tuple); if (tags) { - parse_tls_payload(skb, skb_info.data_off, &tls_hdr, tags); + usm_ctx->tls_header = tls_hdr; + // The connection is TLS encrypted, so trigger some tail calls + // to extract metadata from the payload + goto next_program; } - // The connection is TLS encrypted, thus we cannot further classify the protocol - // using the socket filter and can bail out; return; } @@ -203,6 +206,55 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct classification_next_program(skb, usm_ctx); } +__maybe_unused static __always_inline void protocol_classifier_entrypoint_tls_handshake_client(struct __sk_buff *skb) { + usm_context_t *usm_ctx = usm_context(skb); + if (!usm_ctx) { + return; + } + tls_info_t* tls_info = get_tls_enhanced_tags(&usm_ctx->tuple); + if (!tls_info) { + return; + } + __u64 offset = usm_ctx->skb_info.data_off + sizeof(tls_record_header_t); + if (!is_tls_handshake_client_hello(skb, &usm_ctx->tls_header, offset)) { + goto next_program; + } + parse_client_hello(skb, offset, skb->len, tls_info); + +next_program: + classification_next_program(skb, usm_ctx); +} + +__maybe_unused static __always_inline void protocol_classifier_entrypoint_tls_handshake_server(struct __sk_buff *skb) { + usm_context_t *usm_ctx = usm_context(skb); + if (!usm_ctx) { + return; + } + tls_info_t* tls_info = get_tls_enhanced_tags(&usm_ctx->tuple); + if (!tls_info) { + return; + } + __u64 offset = usm_ctx->skb_info.data_off + sizeof(tls_record_header_t); + if (!is_tls_handshake_server_hello(skb, &usm_ctx->tls_header, offset)) { + goto next_program; + } + parse_server_hello(skb, offset, skb->len, tls_info); + + + protocol_stack_t *protocol_stack = get_protocol_stack(&usm_ctx->tuple); + if (!protocol_stack) { + return; + } + + log_debug("adamk TLS classification done, marking as fully classified"); + update_protocol_information(usm_ctx, protocol_stack, PROTOCOL_TLS); + // mark_as_fully_classified(protocol_stack); + return; + +next_program: + classification_next_program(skb, usm_ctx); +} + __maybe_unused static __always_inline void protocol_classifier_entrypoint_queues(struct __sk_buff *skb) { usm_context_t *usm_ctx = usm_context(skb); if (!usm_ctx) { diff --git a/pkg/network/ebpf/c/protocols/classification/routing-helpers.h b/pkg/network/ebpf/c/protocols/classification/routing-helpers.h index 9e2ba628851c6e..c06e6365c048c6 100644 --- a/pkg/network/ebpf/c/protocols/classification/routing-helpers.h +++ b/pkg/network/ebpf/c/protocols/classification/routing-helpers.h @@ -7,10 +7,10 @@ // has_available_program returns true when there is another program from within // the same protocol layer or false otherwise static __always_inline bool has_available_program(classification_prog_t current_program) { - classification_prog_t next_program = current_program+1; - if (next_program == __PROG_APPLICATION || + classification_prog_t next_program = current_program + 1; + if (next_program == __PROG_ENCRYPTION || + next_program == __PROG_APPLICATION || next_program == __PROG_API || - next_program == __PROG_ENCRYPTION || next_program == CLASSIFICATION_PROG_MAX) { return false; } @@ -42,18 +42,18 @@ static __always_inline bool has_available_program(classification_prog_t current_ // b) current_program > __PROG_ENCRYPTION && current_program < CLASSIFICATION_PROG_MAX #pragma clang diagnostic ignored "-Wtautological-overlap-compare" static __always_inline u16 get_current_program_layer(classification_prog_t current_program) { + if (current_program > __PROG_ENCRYPTION && current_program < __PROG_APPLICATION) { + return LAYER_ENCRYPTION_BIT; + } + if (current_program > __PROG_APPLICATION && current_program < __PROG_API) { return LAYER_APPLICATION_BIT; } - if (current_program > __PROG_API && current_program < __PROG_ENCRYPTION) { + if (current_program > __PROG_API && current_program < CLASSIFICATION_PROG_MAX) { return LAYER_API_BIT; } - if (current_program > __PROG_ENCRYPTION && current_program < CLASSIFICATION_PROG_MAX) { - return LAYER_ENCRYPTION_BIT; - } - return 0; } #pragma clang diagnostic pop @@ -61,15 +61,15 @@ static __always_inline u16 get_current_program_layer(classification_prog_t curre static __always_inline classification_prog_t next_layer_entrypoint(usm_context_t *usm_ctx) { u16 to_skip = usm_ctx->routing_skip_layers; + if (!(to_skip&LAYER_ENCRYPTION_BIT)) { + return __PROG_ENCRYPTION+1; + } if (!(to_skip&LAYER_APPLICATION_BIT)) { return __PROG_APPLICATION+1; } if (!(to_skip&LAYER_API_BIT)) { return __PROG_API+1; } - if (!(to_skip&LAYER_ENCRYPTION_BIT)) { - return __PROG_ENCRYPTION+1; - } return CLASSIFICATION_PROG_UNKNOWN; } diff --git a/pkg/network/ebpf/c/protocols/classification/routing.h b/pkg/network/ebpf/c/protocols/classification/routing.h index 8e2b092e0afba2..1033ed040be616 100644 --- a/pkg/network/ebpf/c/protocols/classification/routing.h +++ b/pkg/network/ebpf/c/protocols/classification/routing.h @@ -58,15 +58,15 @@ static __always_inline void init_routing_cache(usm_context_t *usm_ctx, protocol_ // We skip a given layer in two cases: // 1) If the protocol for that layer is known // 2) If there are no programs registered for that layer + if (stack->layer_encryption || !has_available_program(__PROG_ENCRYPTION)) { + usm_ctx->routing_skip_layers |= LAYER_ENCRYPTION_BIT; + } if (stack->layer_application || !has_available_program(__PROG_APPLICATION)) { usm_ctx->routing_skip_layers |= LAYER_APPLICATION_BIT; } if (stack->layer_api || !has_available_program(__PROG_API)) { usm_ctx->routing_skip_layers |= LAYER_API_BIT; } - if (stack->layer_encryption || !has_available_program(__PROG_ENCRYPTION)) { - usm_ctx->routing_skip_layers |= LAYER_ENCRYPTION_BIT; - } } #endif diff --git a/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h b/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h index ca3a796d0385cf..da8a92b2b4074a 100644 --- a/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h +++ b/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h @@ -27,22 +27,24 @@ static __always_inline protocol_stack_t* __get_protocol_stack(conn_tuple_t* tupl } static __always_inline tls_info_t* get_tls_enhanced_tags(conn_tuple_t* tuple) { - return bpf_map_lookup_elem(&tls_enhanced_tags, tuple); -} - -static __always_inline tls_info_t* get_or_create_tls_enhanced_tags(conn_tuple_t *tuple) { conn_tuple_t normalized_tup = *tuple; normalize_tuple(&normalized_tup); + return bpf_map_lookup_elem(&tls_enhanced_tags, &normalized_tup); +} - tls_info_t *tags = bpf_map_lookup_elem(&tls_enhanced_tags, &normalized_tup); +static __always_inline tls_info_t* get_or_create_tls_enhanced_tags(conn_tuple_t *tuple) { + tls_info_t *tags = get_tls_enhanced_tags(tuple); if (!tags) { - tls_info_t empty_tags = {0}; - bpf_map_update_elem(&tls_enhanced_tags, &normalized_tup, &empty_tags, BPF_NOEXIST); + conn_tuple_t normalized_tup = *tuple; + normalize_tuple(&normalized_tup); + tls_info_t empty_tags = {.reserved = 1}; + bpf_map_update_elem(&tls_enhanced_tags, &normalized_tup, &empty_tags, BPF_ANY); tags = bpf_map_lookup_elem(&tls_enhanced_tags, &normalized_tup); } return tags; } + static __always_inline protocol_stack_t* get_protocol_stack(conn_tuple_t *skb_tup) { conn_tuple_t normalized_tup = *skb_tup; normalize_tuple(&normalized_tup); diff --git a/pkg/network/ebpf/c/protocols/classification/stack-helpers.h b/pkg/network/ebpf/c/protocols/classification/stack-helpers.h index 5cba1cdf8e30ac..6b934c0c7b005c 100644 --- a/pkg/network/ebpf/c/protocols/classification/stack-helpers.h +++ b/pkg/network/ebpf/c/protocols/classification/stack-helpers.h @@ -9,15 +9,15 @@ // get_protocol_layer(PROTOCOL_HTTP) => LAYER_APPLICATION // get_protocol_layer(PROTOCOL_TLS) => LAYER_ENCRYPTION static __always_inline protocol_layer_t get_protocol_layer(protocol_t proto) { - u16 layer_bit = proto&(LAYER_API_BIT|LAYER_APPLICATION_BIT|LAYER_ENCRYPTION_BIT); + u16 layer_bit = proto&(LAYER_ENCRYPTION_BIT|LAYER_API_BIT|LAYER_APPLICATION_BIT); switch(layer_bit) { + case LAYER_ENCRYPTION_BIT: + return LAYER_ENCRYPTION; case LAYER_API_BIT: return LAYER_API; case LAYER_APPLICATION_BIT: return LAYER_APPLICATION; - case LAYER_ENCRYPTION_BIT: - return LAYER_ENCRYPTION; } return LAYER_UNKNOWN; @@ -37,15 +37,15 @@ static __always_inline void set_protocol(protocol_stack_t *stack, protocol_t pro // this is the the number of the protocol without the layer bit set __u8 proto_num = (__u8)proto; switch(layer) { + case LAYER_ENCRYPTION: + stack->layer_encryption = proto_num; + return; case LAYER_API: stack->layer_api = proto_num; return; case LAYER_APPLICATION: stack->layer_application = proto_num; return; - case LAYER_ENCRYPTION: - stack->layer_encryption = proto_num; - return; default: return; } @@ -92,6 +92,10 @@ __maybe_unused static __always_inline protocol_t get_protocol_from_stack(protoco __u16 proto_num = 0; __u16 layer_bit = 0; switch(layer) { + case LAYER_ENCRYPTION: + proto_num = stack->layer_encryption; + layer_bit = LAYER_ENCRYPTION_BIT; + break; case LAYER_API: proto_num = stack->layer_api; layer_bit = LAYER_API_BIT; @@ -100,10 +104,6 @@ __maybe_unused static __always_inline protocol_t get_protocol_from_stack(protoco proto_num = stack->layer_application; layer_bit = LAYER_APPLICATION_BIT; break; - case LAYER_ENCRYPTION: - proto_num = stack->layer_encryption; - layer_bit = LAYER_ENCRYPTION_BIT; - break; default: break; } @@ -131,15 +131,15 @@ static __always_inline void merge_protocol_stacks(protocol_stack_t *this, protoc return; } + if (!this->layer_encryption) { + this->layer_encryption = that->layer_encryption; + } if (!this->layer_api) { this->layer_api = that->layer_api; } if (!this->layer_application) { this->layer_application = that->layer_application; } - if (!this->layer_encryption) { - this->layer_encryption = that->layer_encryption; - } this->flags |= that->flags; } diff --git a/pkg/network/ebpf/c/protocols/classification/usm-context.h b/pkg/network/ebpf/c/protocols/classification/usm-context.h index 47f637c1291cb8..29d2c420060cd5 100644 --- a/pkg/network/ebpf/c/protocols/classification/usm-context.h +++ b/pkg/network/ebpf/c/protocols/classification/usm-context.h @@ -23,6 +23,7 @@ typedef struct { // bit mask with layers that should be skiped u16 routing_skip_layers; classification_prog_t routing_current_program; + tls_record_header_t tls_header; } usm_context_t; // Kernels before 4.7 do not know about per-cpu array maps. diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index 4bb6215c774f19..63307d1b78fac7 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -16,7 +16,7 @@ #define TLS_APPLICATION_DATA 0x17 #define SUPPORTED_VERSIONS_EXTENSION 0x002B -#define MAX_EXTENSIONS 3 +#define MAX_EXTENSIONS 16 /* https://www.rfc-editor.org/rfc/rfc5246#page-19 6.2. Record Layer */ @@ -204,74 +204,74 @@ static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offse __u16 extension_length; __u8 sv_list_length; - // #pragma unroll(MAX_EXTENSIONS) - for (int i = 0; i < MAX_EXTENSIONS; i++) { - if (offset + 4 > extensions_end) { - break; - } - // Read Extension Type (2 bytes) - if (bpf_skb_load_bytes(skb, offset, &extension_type, sizeof(extension_type)) < 0) - return -1; - extension_type = bpf_ntohs(extension_type); - offset += 2; - - // Read Extension Length (2 bytes) - if (bpf_skb_load_bytes(skb, offset, &extension_length, sizeof(extension_length)) < 0) - return -1; - extension_length = bpf_ntohs(extension_length); - offset += 2; - - // Ensure we don't read beyond the packet - if (offset + extension_length > skb_len || offset + extension_length > extensions_end) - return -1; - - // Check for supported_versions extension (type 43 or 0x002B) - if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { - // Parse supported_versions extension - if (offset + 1 > skb_len) + #pragma unroll(MAX_EXTENSIONS) + for (int i = 0; i < MAX_EXTENSIONS; i++) { + if (offset + 4 > extensions_end) { + break; + } + // Read Extension Type (2 bytes) + if (bpf_skb_load_bytes(skb, offset, &extension_type, sizeof(extension_type)) < 0) return -1; + extension_type = bpf_ntohs(extension_type); + offset += 2; - // Read list length (1 byte) - if (bpf_skb_load_bytes(skb, offset, &sv_list_length, sizeof(sv_list_length)) < 0) + // Read Extension Length (2 bytes) + if (bpf_skb_load_bytes(skb, offset, &extension_length, sizeof(extension_length)) < 0) return -1; - offset += 1; + extension_length = bpf_ntohs(extension_length); + offset += 2; // Ensure we don't read beyond the packet - if (offset + sv_list_length > skb_len || offset + sv_list_length > extensions_end) + if (offset + extension_length > skb_len || offset + extension_length > extensions_end) return -1; - #define MAX_SUPPORTED_VERSIONS 6 - __u8 num_versions = 0; - __u8 i = 0; - __u16 sv_version; - - // #pragma unroll(MAX_SUPPORTED_VERSIONS) - for (int idx = 0; idx < MAX_SUPPORTED_VERSIONS; idx++) { - if (i + 1 >= sv_list_length) - break; - if (offset + 2 > skb_len) + // Check for supported_versions extension (type 43 or 0x002B) + if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { + // Parse supported_versions extension + if (offset + 1 > skb_len) return -1; - // Load the supported version - if (bpf_skb_load_bytes(skb, offset, &sv_version, sizeof(sv_version)) < 0) + // Read list length (1 byte) + if (bpf_skb_load_bytes(skb, offset, &sv_list_length, sizeof(sv_list_length)) < 0) return -1; - sv_version = bpf_ntohs(sv_version); - offset += 2; + offset += 1; - // Store the version - set_tls_offered_version(tags, sv_version); + // Ensure we don't read beyond the packet + if (offset + sv_list_length > skb_len || offset + sv_list_length > extensions_end) + return -1; - num_versions++; - i += 2; + #define MAX_SUPPORTED_VERSIONS 4 + __u8 num_versions = 0; + __u8 i = 0; + __u16 sv_version; + + #pragma unroll(MAX_SUPPORTED_VERSIONS) + for (int idx = 0; idx < MAX_SUPPORTED_VERSIONS; idx++) { + if (i + 1 >= sv_list_length) + break; + if (offset + 2 > skb_len) + return -1; + + // Load the supported version + if (bpf_skb_load_bytes(skb, offset, &sv_version, sizeof(sv_version)) < 0) + return -1; + sv_version = bpf_ntohs(sv_version); + offset += 2; + + // Store the version + set_tls_offered_version(tags, sv_version); + + num_versions++; + i += 2; + } + } else { + // Skip other extensions + offset += extension_length; } - } else { - // Skip other extensions - offset += extension_length; - } - - extensions_parsed++; - } + extensions_parsed++; + } + log_debug("adamk successfully parsed client hello message"); return 0; } @@ -358,75 +358,87 @@ static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offse __u16 extension_type; __u16 extension_length; __u16 selected_version; - // #pragma unroll(MAX_EXTENSIONS) - for (int i = 0; i < MAX_EXTENSIONS; i++) { - if (offset + 4 > extensions_end) { - break; - } - // Read Extension Type (2 bytes) - if (bpf_skb_load_bytes(skb, offset, &extension_type, sizeof(extension_type)) < 0) - return -1; - extension_type = bpf_ntohs(extension_type); - offset += 2; - - // Read Extension Length (2 bytes) - if (bpf_skb_load_bytes(skb, offset, &extension_length, sizeof(extension_length)) < 0) - return -1; - extension_length = bpf_ntohs(extension_length); - offset += 2; - - // Ensure we don't read beyond the packet - if (offset + extension_length > skb_len || offset + extension_length > extensions_end) - return -1; - - // Check for supported_versions extension (type 43 or 0x002B) - if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { - // Parse supported_versions extension - if (extension_length != 2) + #pragma unroll(MAX_EXTENSIONS) + for (int i = 0; i < MAX_EXTENSIONS; i++) { + if (offset + 4 > extensions_end) { + break; + } + // Read Extension Type (2 bytes) + if (bpf_skb_load_bytes(skb, offset, &extension_type, sizeof(extension_type)) < 0) return -1; + extension_type = bpf_ntohs(extension_type); + offset += 2; - if (offset + 2 > skb_len) + // Read Extension Length (2 bytes) + if (bpf_skb_load_bytes(skb, offset, &extension_length, sizeof(extension_length)) < 0) return -1; + extension_length = bpf_ntohs(extension_length); + offset += 2; - // Read selected version (2 bytes) - if (bpf_skb_load_bytes(skb, offset, &selected_version, sizeof(selected_version)) < 0) + // Ensure we don't read beyond the packet + if (offset + extension_length > skb_len || offset + extension_length > extensions_end) return -1; - selected_version = bpf_ntohs(selected_version); - offset += 2; - tags->chosen_version = selected_version; - } else { - // Skip other extensions - offset += extension_length; - } + // Check for supported_versions extension (type 43 or 0x002B) + if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { + // Parse supported_versions extension + if (extension_length != 2) + return -1; - extensions_parsed++; + if (offset + 2 > skb_len) + return -1; + + // Read selected version (2 bytes) + if (bpf_skb_load_bytes(skb, offset, &selected_version, sizeof(selected_version)) < 0) + return -1; + selected_version = bpf_ntohs(selected_version); + offset += 2; + + tags->chosen_version = selected_version; + } else { + // Skip other extensions + offset += extension_length; + } + + extensions_parsed++; + } } - } + + log_debug("adamk successfully parsed server hello message"); return 0; } -// parse_tls_payload parses the TLS payload and populates select tags -static __always_inline int parse_tls_payload(struct __sk_buff *skb, __u64 nh_off, tls_record_header_t *tls_hdr, tls_info_t *tags) { - // At this point, tls_hdr has already been validated and filled by is_tls() - __u64 offset = nh_off + sizeof(tls_record_header_t); +static __always_inline bool is_tls_handshake_client_hello(struct __sk_buff *skb, tls_record_header_t *tls_hdr, __u64 offset) { + if (!tls_hdr) { + return false; + } + if (tls_hdr->content_type != TLS_HANDSHAKE) { + return false; + } - if (tls_hdr->content_type == TLS_HANDSHAKE) { - __u8 handshake_type; - if (bpf_skb_load_bytes(skb, offset, &handshake_type, sizeof(handshake_type)) < 0) - return -1; + // Read handshake type + __u8 handshake_type; + if (bpf_skb_load_bytes(skb, offset, &handshake_type, sizeof(handshake_type)) < 0) { + return false; + } + return handshake_type == TLS_HANDSHAKE_CLIENT_HELLO; +} - if (handshake_type == TLS_HANDSHAKE_CLIENT_HELLO) { - return parse_client_hello(skb, offset, skb->len, tags); - } else if (handshake_type == TLS_HANDSHAKE_SERVER_HELLO) { - return parse_server_hello(skb, offset, skb->len, tags); - } else { - return -1; - } - } else { - return -1; +static __always_inline bool is_tls_handshake_server_hello(struct __sk_buff *skb, tls_record_header_t *tls_hdr, __u64 offset) { + if (!tls_hdr) { + return false; } + if (tls_hdr->content_type != TLS_HANDSHAKE) { + return false; + } + + // Read handshake type + __u8 handshake_type; + if (bpf_skb_load_bytes(skb, offset, &handshake_type, sizeof(handshake_type)) < 0) + return false; + + return handshake_type == TLS_HANDSHAKE_SERVER_HELLO; } #endif // __TLS_H diff --git a/pkg/network/ebpf/c/tracer.c b/pkg/network/ebpf/c/tracer.c index e83d344906c302..59a32bfffeba7b 100644 --- a/pkg/network/ebpf/c/tracer.c +++ b/pkg/network/ebpf/c/tracer.c @@ -32,6 +32,18 @@ int socket__classifier_entry(struct __sk_buff *skb) { return 0; } +SEC("socket/classifier_tls_handshake_client") +int socket__classifier_tls_handshake_client(struct __sk_buff *skb) { + protocol_classifier_entrypoint_tls_handshake_client(skb); + return 0; +} + +SEC("socket/classifier_tls_handshake_server") +int socket__classifier_tls_handshake_server(struct __sk_buff *skb) { + protocol_classifier_entrypoint_tls_handshake_server(skb); + return 0; +} + SEC("socket/classifier_queues") int socket__classifier_queues(struct __sk_buff *skb) { protocol_classifier_entrypoint_queues(skb); diff --git a/pkg/network/ebpf/kprobe_types.go b/pkg/network/ebpf/kprobe_types.go index 110ee0a91dd9fd..663fdae87de547 100644 --- a/pkg/network/ebpf/kprobe_types.go +++ b/pkg/network/ebpf/kprobe_types.go @@ -64,7 +64,9 @@ const SizeofConn = C.sizeof_conn_t type ClassificationProgram = uint32 const ( - ClassificationQueues ClassificationProgram = C.CLASSIFICATION_QUEUES_PROG - ClassificationDBs ClassificationProgram = C.CLASSIFICATION_DBS_PROG - ClassificationGRPC ClassificationProgram = C.CLASSIFICATION_GRPC_PROG + ClassificationTLSClient ClassificationProgram = C.CLASSIFICATION_TLS_CLIENT_PROG + ClassificationTLSServer ClassificationProgram = C.CLASSIFICATION_TLS_SERVER_PROG + ClassificationQueues ClassificationProgram = C.CLASSIFICATION_QUEUES_PROG + ClassificationDBs ClassificationProgram = C.CLASSIFICATION_DBS_PROG + ClassificationGRPC ClassificationProgram = C.CLASSIFICATION_GRPC_PROG ) diff --git a/pkg/network/ebpf/kprobe_types_linux.go b/pkg/network/ebpf/kprobe_types_linux.go index bb283f59cc9dad..a5bb92fea6061d 100644 --- a/pkg/network/ebpf/kprobe_types_linux.go +++ b/pkg/network/ebpf/kprobe_types_linux.go @@ -94,9 +94,9 @@ type BindSyscallArgs struct { Sk uint64 } type ProtocolStack struct { + Encryption uint8 Api uint8 Application uint8 - Encryption uint8 Flags uint8 } type ProtocolStackWrapper struct { @@ -141,7 +141,9 @@ const SizeofConn = 0x78 type ClassificationProgram = uint32 const ( - ClassificationQueues ClassificationProgram = 0x2 - ClassificationDBs ClassificationProgram = 0x3 - ClassificationGRPC ClassificationProgram = 0x5 + ClassificationTLSClient ClassificationProgram = 0x2 + ClassificationTLSServer ClassificationProgram = 0x3 + ClassificationQueues ClassificationProgram = 0x5 + ClassificationDBs ClassificationProgram = 0x6 + ClassificationGRPC ClassificationProgram = 0x8 ) diff --git a/pkg/network/ebpf/probes/probes.go b/pkg/network/ebpf/probes/probes.go index 4075e6f7616181..25c93cb5090660 100644 --- a/pkg/network/ebpf/probes/probes.go +++ b/pkg/network/ebpf/probes/probes.go @@ -27,6 +27,10 @@ const ( // ProtocolClassifierEntrySocketFilter runs a classifier algorithm as a socket filter ProtocolClassifierEntrySocketFilter ProbeFuncName = "socket__classifier_entry" + // ProtocolClassifierTLSClientSocketFilter runs classification rules for the TLS client hello packet + ProtocolClassifierTLSClientSocketFilter ProbeFuncName = "socket__classifier_tls_handshake_client" + // ProtocolClassifierTLSServerSocketFilter runs classification rules for the TLS server hello packet + ProtocolClassifierTLSServerSocketFilter ProbeFuncName = "socket__classifier_tls_handshake_server" // ProtocolClassifierQueuesSocketFilter runs a classification rules for Queue protocols. ProtocolClassifierQueuesSocketFilter ProbeFuncName = "socket__classifier_queues" // ProtocolClassifierDBsSocketFilter runs a classification rules for DB protocols. diff --git a/pkg/network/tracer/connection/kprobe/config.go b/pkg/network/tracer/connection/kprobe/config.go index 880a2f0a5e8388..b7c6bb2ff99865 100644 --- a/pkg/network/tracer/connection/kprobe/config.go +++ b/pkg/network/tracer/connection/kprobe/config.go @@ -58,6 +58,8 @@ func enabledProbes(c *config.Config, runtimeTracer, coreTracer bool) (map[probes if c.CollectTCPv4Conns || c.CollectTCPv6Conns { if ClassificationSupported(c) { enableProbe(enabled, probes.ProtocolClassifierEntrySocketFilter) + enableProbe(enabled, probes.ProtocolClassifierTLSClientSocketFilter) + enableProbe(enabled, probes.ProtocolClassifierTLSServerSocketFilter) enableProbe(enabled, probes.ProtocolClassifierQueuesSocketFilter) enableProbe(enabled, probes.ProtocolClassifierDBsSocketFilter) enableProbe(enabled, probes.ProtocolClassifierGRPCSocketFilter) diff --git a/pkg/network/tracer/connection/kprobe/manager.go b/pkg/network/tracer/connection/kprobe/manager.go index ceb37144b3c8af..5f0f40efd71ff0 100644 --- a/pkg/network/tracer/connection/kprobe/manager.go +++ b/pkg/network/tracer/connection/kprobe/manager.go @@ -19,6 +19,8 @@ import ( var mainProbes = []probes.ProbeFuncName{ probes.NetDevQueue, probes.ProtocolClassifierEntrySocketFilter, + probes.ProtocolClassifierTLSClientSocketFilter, + probes.ProtocolClassifierTLSServerSocketFilter, probes.ProtocolClassifierQueuesSocketFilter, probes.ProtocolClassifierDBsSocketFilter, probes.ProtocolClassifierGRPCSocketFilter, diff --git a/pkg/network/tracer/connection/kprobe/tracer.go b/pkg/network/tracer/connection/kprobe/tracer.go index 1ed4ed1873b07f..56f282b5e20346 100644 --- a/pkg/network/tracer/connection/kprobe/tracer.go +++ b/pkg/network/tracer/connection/kprobe/tracer.go @@ -42,9 +42,25 @@ var ( // The kernel has to be newer than 4.11.0 since we are using bpf_skb_load_bytes (4.5.0+), which // was added to socket filters in 4.11.0: // - 2492d3b867043f6880708d095a7a5d65debcfc32 - classificationMinimumKernel = kernel.VersionCode(4, 15, 0) + classificationMinimumKernel = kernel.VersionCode(4, 11, 0) protocolClassificationTailCalls = []manager.TailCallRoute{ + { + ProgArrayName: probes.ClassificationProgsMap, + Key: netebpf.ClassificationTLSClient, + ProbeIdentificationPair: manager.ProbeIdentificationPair{ + EBPFFuncName: probes.ProtocolClassifierTLSClientSocketFilter, + UID: probeUID, + }, + }, + { + ProgArrayName: probes.ClassificationProgsMap, + Key: netebpf.ClassificationTLSServer, + ProbeIdentificationPair: manager.ProbeIdentificationPair{ + EBPFFuncName: probes.ProtocolClassifierTLSServerSocketFilter, + UID: probeUID, + }, + }, { ProgArrayName: probes.ClassificationProgsMap, Key: netebpf.ClassificationQueues, @@ -90,8 +106,8 @@ var ( ) // ClassificationSupported returns true if the current kernel version supports the classification feature. -// The kernel has to be newer than 4.7.0 since we are using bpf_skb_load_bytes (4.5.0+) method to read from the socket -// filter, and a tracepoint (4.7.0+) +// The kernel has to be newer than 4.11.0 since we are using bpf_skb_load_bytes (4.5.0+) method to read from the socket +// filter which was added in 4.11, and a tracepoint (4.7.0+) func ClassificationSupported(config *config.Config) bool { if !config.ProtocolClassificationEnabled { return false diff --git a/pkg/network/tracer/tracer_linux_test.go b/pkg/network/tracer/tracer_linux_test.go index 54398001b016cf..43af108618f67c 100644 --- a/pkg/network/tracer/tracer_linux_test.go +++ b/pkg/network/tracer/tracer_linux_test.go @@ -2570,8 +2570,9 @@ func (s *TracerSuite) TestTLSClassification() { if !kprobe.ClassificationSupported(cfg) { t.Skip("TLS classification platform not supported") } - port, err := tracertestutil.GetFreePort() - require.NoError(t, err) + //port, err := tracertestutil.GetFreePort() + port := uint16(44957) + //require.NoError(t, err) portAsString := strconv.Itoa(int(port)) tr := setupTracer(t, cfg) @@ -2583,6 +2584,7 @@ func (s *TracerSuite) TestTLSClassification() { } tests := make([]tlsTest, 0) for _, scenario := range []uint16{tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13} { + //for _, scenario := range []uint16{tls.VersionTLS12} { scenario := scenario tests = append(tests, tlsTest{ name: strings.Replace(tls.VersionName(scenario), " ", "-", 1), From 91b33df41c0466096b072a9af665ae88d5ba1d02 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Mon, 2 Dec 2024 19:38:02 -0500 Subject: [PATCH 27/53] bail early for appdata --- .../classification/protocol-classification.h | 31 +++++++++++-------- .../ebpf/c/protocols/classification/routing.h | 5 +-- pkg/network/ebpf/c/protocols/tls/tls.h | 6 ++-- pkg/network/tracer/tracer_linux_test.go | 7 +++-- 4 files changed, 28 insertions(+), 21 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h index 0596fe3a45778a..d3c7350d3ef625 100644 --- a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h +++ b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h @@ -63,7 +63,9 @@ // updates the the protocol stack and adds the current layer to the routing skip list static __always_inline void update_protocol_information(usm_context_t *usm_ctx, protocol_stack_t *stack, protocol_t proto) { set_protocol(stack, proto); - usm_ctx->routing_skip_layers |= proto; + if (proto != PROTOCOL_TLS) { + usm_ctx->routing_skip_layers |= proto; + } } // Check if the connections is used for gRPC traffic. @@ -165,17 +167,20 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct tls_record_header_t tls_hdr = {0}; + // TLS classification if ((app_layer_proto == PROTOCOL_UNKNOWN || app_layer_proto == PROTOCOL_POSTGRES) && is_tls(skb, skb_info.data_off, &tls_hdr)) { - // TLS classification - log_debug("adamk TLS classification started"); - // TODO: check if it's a TLS app data message. If so we're too late, label as TLS and bail from the classification - // update_protocol_information(usm_ctx, protocol_stack, PROTOCOL_TLS); + update_protocol_information(usm_ctx, protocol_stack, PROTOCOL_TLS); + if (tls_hdr.content_type == TLS_APPLICATION_DATA) { + // We can't classify TLS encrypted traffic further, so we mark the stack as fully classified + mark_as_fully_classified(protocol_stack); + return; + } // Parse TLS payload tls_info_t *tags = get_or_create_tls_enhanced_tags(&usm_ctx->tuple); if (tags) { usm_ctx->tls_header = tls_hdr; - // The connection is TLS encrypted, so trigger some tail calls + // The packet is a TLS handshake, so trigger some tail calls // to extract metadata from the payload goto next_program; } @@ -219,7 +224,9 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint_tls_ha if (!is_tls_handshake_client_hello(skb, &usm_ctx->tls_header, offset)) { goto next_program; } - parse_client_hello(skb, offset, skb->len, tls_info); + if (parse_client_hello(skb, offset, skb->len, tls_info) != 0) { + return; + } next_program: classification_next_program(skb, usm_ctx); @@ -238,17 +245,15 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint_tls_ha if (!is_tls_handshake_server_hello(skb, &usm_ctx->tls_header, offset)) { goto next_program; } - parse_server_hello(skb, offset, skb->len, tls_info); - + if (parse_server_hello(skb, offset, skb->len, tls_info) != 0) { + return; + } protocol_stack_t *protocol_stack = get_protocol_stack(&usm_ctx->tuple); if (!protocol_stack) { return; } - - log_debug("adamk TLS classification done, marking as fully classified"); - update_protocol_information(usm_ctx, protocol_stack, PROTOCOL_TLS); - // mark_as_fully_classified(protocol_stack); + mark_as_fully_classified(protocol_stack); return; next_program: diff --git a/pkg/network/ebpf/c/protocols/classification/routing.h b/pkg/network/ebpf/c/protocols/classification/routing.h index 1033ed040be616..1320bdaab7318e 100644 --- a/pkg/network/ebpf/c/protocols/classification/routing.h +++ b/pkg/network/ebpf/c/protocols/classification/routing.h @@ -56,9 +56,10 @@ static __always_inline void init_routing_cache(usm_context_t *usm_ctx, protocol_ usm_ctx->routing_current_program = CLASSIFICATION_PROG_UNKNOWN; // We skip a given layer in two cases: - // 1) If the protocol for that layer is known + // 1) If the protocol for that layer is known, + // except for encryption as it still needs to be traversed for metadata // 2) If there are no programs registered for that layer - if (stack->layer_encryption || !has_available_program(__PROG_ENCRYPTION)) { + if (stack->flags == FLAG_FULLY_CLASSIFIED || !has_available_program(__PROG_ENCRYPTION)) { usm_ctx->routing_skip_layers |= LAYER_ENCRYPTION_BIT; } if (stack->layer_application || !has_available_program(__PROG_APPLICATION)) { diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index 63307d1b78fac7..020e3d8215e2cf 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -16,7 +16,7 @@ #define TLS_APPLICATION_DATA 0x17 #define SUPPORTED_VERSIONS_EXTENSION 0x002B -#define MAX_EXTENSIONS 16 +#define MAX_EXTENSIONS 10 /* https://www.rfc-editor.org/rfc/rfc5246#page-19 6.2. Record Layer */ @@ -271,7 +271,7 @@ static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offse extensions_parsed++; } - log_debug("adamk successfully parsed client hello message"); + // log_debug("adamk successfully parsed client hello message"); return 0; } @@ -404,7 +404,7 @@ static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offse } } - log_debug("adamk successfully parsed server hello message"); + // log_debug("adamk successfully parsed server hello message"); return 0; } diff --git a/pkg/network/tracer/tracer_linux_test.go b/pkg/network/tracer/tracer_linux_test.go index 43af108618f67c..01ad7ddb7a9cd5 100644 --- a/pkg/network/tracer/tracer_linux_test.go +++ b/pkg/network/tracer/tracer_linux_test.go @@ -2570,9 +2570,9 @@ func (s *TracerSuite) TestTLSClassification() { if !kprobe.ClassificationSupported(cfg) { t.Skip("TLS classification platform not supported") } - //port, err := tracertestutil.GetFreePort() - port := uint16(44957) - //require.NoError(t, err) + port, err := tracertestutil.GetFreePort() + //port := uint16(44957) + require.NoError(t, err) portAsString := strconv.Itoa(int(port)) tr := setupTracer(t, cfg) @@ -2625,6 +2625,7 @@ func (s *TracerSuite) TestTLSClassification() { if c.DPort == port && c.ProtocolStack.Contains(protocols.TLS) && !c.TLSTags.IsEmpty() { expectedTagKey := ddtls.TagTLSVersion + tls.VersionName(scenario) tlsTags := ddtls.GetTLSDynamicTags(&c.TLSTags) + t.Log("TLS tags: ", tlsTags) if _, ok := tlsTags[expectedTagKey]; !ok { return false } From dc80849fa74e63898674a00c52f55ace9a122b99 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Mon, 2 Dec 2024 20:35:42 -0500 Subject: [PATCH 28/53] make test less brittle --- .../classification/protocol-classification.h | 2 +- pkg/network/protocols/tls/types.go | 8 +-- pkg/network/tracer/tracer_linux_test.go | 68 +++++++++++++++---- 3 files changed, 60 insertions(+), 18 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h index d3c7350d3ef625..c0bc2a8f0b23db 100644 --- a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h +++ b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h @@ -154,7 +154,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct return; } - if (is_fully_classified(protocol_stack) ) { + if (is_fully_classified(protocol_stack)) { return; } diff --git a/pkg/network/protocols/tls/types.go b/pkg/network/protocols/tls/types.go index 1ea3fe85d46fba..4a6c41a2d2a4ea 100644 --- a/pkg/network/protocols/tls/types.go +++ b/pkg/network/protocols/tls/types.go @@ -54,8 +54,8 @@ var offeredVersionBitmask = []struct { // Constants for tag keys const ( TagTLSVersion = "tls.version:" - tagTLSCipherSuiteID = "tls.cipher_suite_id:" - tagTLSClientVersion = "tls.client_version:" + TagTLSCipherSuiteID = "tls.cipher_suite_id:" + TagTLSClientVersion = "tls.client_version:" ) // Tags holds the TLS tags. It is used to store the TLS version, cipher suite and offered versions. @@ -125,12 +125,12 @@ func GetTLSDynamicTags(tls *Tags) map[string]struct{} { // Cipher suite ID as hex string if tls.CipherSuite != 0 { - tags[tagTLSCipherSuiteID+fmt.Sprintf("0x%04X", tls.CipherSuite)] = struct{}{} + tags[TagTLSCipherSuiteID+fmt.Sprintf("0x%04X", tls.CipherSuite)] = struct{}{} } // Client offered versions for _, versionName := range parseOfferedVersions(tls.OfferedVersions) { - tags[tagTLSClientVersion+versionName] = struct{}{} + tags[TagTLSClientVersion+versionName] = struct{}{} } return tags diff --git a/pkg/network/tracer/tracer_linux_test.go b/pkg/network/tracer/tracer_linux_test.go index 01ad7ddb7a9cd5..a0e2281dee46f6 100644 --- a/pkg/network/tracer/tracer_linux_test.go +++ b/pkg/network/tracer/tracer_linux_test.go @@ -2620,19 +2620,7 @@ func (s *TracerSuite) TestTLSClassification() { }, validation: func(t *testing.T, tr *Tracer) { require.Eventuallyf(t, func() bool { - payload := getConnections(t, tr) - for _, c := range payload.Conns { - if c.DPort == port && c.ProtocolStack.Contains(protocols.TLS) && !c.TLSTags.IsEmpty() { - expectedTagKey := ddtls.TagTLSVersion + tls.VersionName(scenario) - tlsTags := ddtls.GetTLSDynamicTags(&c.TLSTags) - t.Log("TLS tags: ", tlsTags) - if _, ok := tlsTags[expectedTagKey]; !ok { - return false - } - return true - } - } - return false + return validateTLSTags(t, tr, port, scenario) }, 3*time.Second, 100*time.Millisecond, "couldn't find TLS connection matching: dst port %v", portAsString) }, }) @@ -2656,6 +2644,60 @@ func (s *TracerSuite) TestTLSClassification() { } } +func validateTLSTags(t *testing.T, tr *Tracer, port uint16, scenario uint16) bool { + payload := getConnections(t, tr) + for _, c := range payload.Conns { + if c.DPort == port && c.ProtocolStack.Contains(protocols.TLS) && !c.TLSTags.IsEmpty() { + tlsTags := ddtls.GetTLSDynamicTags(&c.TLSTags) + t.Log("TLS tags: ", tlsTags) + + // Check that the cipher suite ID tag is present + cipherSuiteTagFound := false + for key := range tlsTags { + if strings.HasPrefix(key, ddtls.TagTLSCipherSuiteID) { + cipherSuiteTagFound = true + break + } + } + if !cipherSuiteTagFound { + t.Log("Cipher suite ID tag missing") + return false + } + + // Check that the negotiated version tag is present + negotiatedVersionTag := ddtls.TagTLSVersion + tls.VersionName(scenario) + if _, ok := tlsTags[negotiatedVersionTag]; !ok { + t.Logf("Negotiated version tag '%s' not found", negotiatedVersionTag) + return false + } + + // Check that the client offered version tag is present + clientVersionTag := ddtls.TagTLSClientVersion + tls.VersionName(scenario) + if _, ok := tlsTags[clientVersionTag]; !ok { + t.Logf("Client offered version tag '%s' not found", clientVersionTag) + return false + } + + // Optionally, check for multiple offered versions (e.g., for TLS 1.3) + if scenario == tls.VersionTLS13 { + expectedClientVersions := []string{ + ddtls.TagTLSClientVersion + tls.VersionName(tls.VersionTLS12), + ddtls.TagTLSClientVersion + tls.VersionName(tls.VersionTLS13), + } + for _, tag := range expectedClientVersions { + if _, ok := tlsTags[tag]; !ok { + t.Logf("Expected client offered version tag '%s' not found", tag) + return false + } + } + } + + return true + } + } + return false +} + func skipOnEbpflessNotSupported(t *testing.T, cfg *config.Config) { if cfg.EnableEbpfless { t.Skip("not supported on ebpf-less") From 502f0b68490d0257ca1c0b559a30badcf70aca95 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Mon, 2 Dec 2024 21:48:14 -0500 Subject: [PATCH 29/53] fix existing usm classification --- .../ebpf/c/protocols/classification/protocol-classification.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h index c0bc2a8f0b23db..d2127f8e7adfe2 100644 --- a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h +++ b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h @@ -218,7 +218,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint_tls_ha } tls_info_t* tls_info = get_tls_enhanced_tags(&usm_ctx->tuple); if (!tls_info) { - return; + goto next_program; } __u64 offset = usm_ctx->skb_info.data_off + sizeof(tls_record_header_t); if (!is_tls_handshake_client_hello(skb, &usm_ctx->tls_header, offset)) { @@ -239,7 +239,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint_tls_ha } tls_info_t* tls_info = get_tls_enhanced_tags(&usm_ctx->tuple); if (!tls_info) { - return; + goto next_program; } __u64 offset = usm_ctx->skb_info.data_off + sizeof(tls_record_header_t); if (!is_tls_handshake_server_hello(skb, &usm_ctx->tls_header, offset)) { From a424632ddb45d3dce2fff00375b60e7cd5d242e0 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Mon, 2 Dec 2024 22:35:15 -0500 Subject: [PATCH 30/53] reduce max extensions --- .../ebpf/c/protocols/classification/protocol-classification.h | 1 + pkg/network/ebpf/c/protocols/tls/tls.h | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h index d2127f8e7adfe2..9edf14ac164bac 100644 --- a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h +++ b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h @@ -254,6 +254,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint_tls_ha return; } mark_as_fully_classified(protocol_stack); + usm_ctx->tls_header = (tls_record_header_t){0}; return; next_program: diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index 020e3d8215e2cf..cbc3c3b6c816a9 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -16,7 +16,7 @@ #define TLS_APPLICATION_DATA 0x17 #define SUPPORTED_VERSIONS_EXTENSION 0x002B -#define MAX_EXTENSIONS 10 +#define MAX_EXTENSIONS 8 /* https://www.rfc-editor.org/rfc/rfc5246#page-19 6.2. Record Layer */ From 1506f8769ebf459d787b1c221eb7741745e09772 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Mon, 2 Dec 2024 23:19:32 -0500 Subject: [PATCH 31/53] use go tls library for constants --- pkg/network/ebpf/c/protocols/tls/tls.h | 29 ++++++---------- pkg/network/protocols/tls/types.go | 45 +++++++++---------------- pkg/network/protocols/tls/types_test.go | 27 ++++++--------- pkg/network/tracer/tracer_linux_test.go | 1 - 4 files changed, 38 insertions(+), 64 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index cbc3c3b6c816a9..1a448a7abe63b0 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -16,7 +16,8 @@ #define TLS_APPLICATION_DATA 0x17 #define SUPPORTED_VERSIONS_EXTENSION 0x002B -#define MAX_EXTENSIONS 8 +#define CLIENT_MAX_EXTENSIONS 4 +#define SERVER_MAX_EXTENSIONS 6 /* https://www.rfc-editor.org/rfc/rfc5246#page-19 6.2. Record Layer */ @@ -50,23 +51,17 @@ static __always_inline bool is_valid_tls_version(__u16 version) { // set_tls_offered_version sets the bit corresponding to the offered version in the offered_versions field of tls_info static __always_inline void set_tls_offered_version(tls_info_t *tls_info, __u16 version) { switch (version) { - case SSL_VERSION20: - tls_info->offered_versions |= 0x01; // Bit 0 - break; - case SSL_VERSION30: - tls_info->offered_versions |= 0x02; // Bit 1 - break; case TLS_VERSION10: - tls_info->offered_versions |= 0x04; // Bit 2 + tls_info->offered_versions |= 0x01; break; case TLS_VERSION11: - tls_info->offered_versions |= 0x08; // Bit 3 + tls_info->offered_versions |= 0x02; break; case TLS_VERSION12: - tls_info->offered_versions |= 0x10; // Bit 4 + tls_info->offered_versions |= 0x04; break; case TLS_VERSION13: - tls_info->offered_versions |= 0x20; // Bit 5 + tls_info->offered_versions |= 0x08; break; default: break; @@ -204,8 +199,8 @@ static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offse __u16 extension_length; __u8 sv_list_length; - #pragma unroll(MAX_EXTENSIONS) - for (int i = 0; i < MAX_EXTENSIONS; i++) { + #pragma unroll(CLIENT_MAX_EXTENSIONS) + for (int i = 0; i < CLIENT_MAX_EXTENSIONS; i++) { if (offset + 4 > extensions_end) { break; } @@ -271,7 +266,7 @@ static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offse extensions_parsed++; } - // log_debug("adamk successfully parsed client hello message"); + return 0; } @@ -358,8 +353,8 @@ static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offse __u16 extension_type; __u16 extension_length; __u16 selected_version; - #pragma unroll(MAX_EXTENSIONS) - for (int i = 0; i < MAX_EXTENSIONS; i++) { + #pragma unroll(SERVER_MAX_EXTENSIONS) + for (int i = 0; i < SERVER_MAX_EXTENSIONS; i++) { if (offset + 4 > extensions_end) { break; } @@ -404,8 +399,6 @@ static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offse } } - // log_debug("adamk successfully parsed server hello message"); - return 0; } diff --git a/pkg/network/protocols/tls/types.go b/pkg/network/protocols/tls/types.go index 4a6c41a2d2a4ea..348bcdbd407fbd 100644 --- a/pkg/network/protocols/tls/types.go +++ b/pkg/network/protocols/tls/types.go @@ -6,36 +6,25 @@ // Package tls contains definitions and methods related to tags parsed from the TLS handshake package tls -import "fmt" - -// TLS and SSL version constants -const ( - SSLVersion20 uint16 = 0x0200 - SSLVersion30 uint16 = 0x0300 - TLSVersion10 uint16 = 0x0301 - TLSVersion11 uint16 = 0x0302 - TLSVersion12 uint16 = 0x0303 - TLSVersion13 uint16 = 0x0304 +import ( + "crypto/tls" + "fmt" ) -// Bitmask constants for Offered_versions +// Bitmask constants for Offered_versions matching kernelspace definitions const ( - OfferedSSLVersion20 uint8 = 0x01 // Bit 0 - OfferedSSLVersion30 uint8 = 0x02 // Bit 1 - OfferedTLSVersion10 uint8 = 0x04 // Bit 2 - OfferedTLSVersion11 uint8 = 0x08 // Bit 3 - OfferedTLSVersion12 uint8 = 0x10 // Bit 4 - OfferedTLSVersion13 uint8 = 0x20 // Bit 5 + OfferedTLSVersion10 uint8 = 0x01 + OfferedTLSVersion11 uint8 = 0x02 + OfferedTLSVersion12 uint8 = 0x04 + OfferedTLSVersion13 uint8 = 0x08 ) // mapping of version constants to their string representations var tlsVersionNames = map[uint16]string{ - SSLVersion20: "SSL 2.0", - SSLVersion30: "SSL 3.0", - TLSVersion10: "TLS 1.0", - TLSVersion11: "TLS 1.1", - TLSVersion12: "TLS 1.2", - TLSVersion13: "TLS 1.3", + tls.VersionTLS10: "TLS 1.0", + tls.VersionTLS11: "TLS 1.1", + tls.VersionTLS12: "TLS 1.2", + tls.VersionTLS13: "TLS 1.3", } // Mapping of offered version bitmasks to version constants @@ -43,12 +32,10 @@ var offeredVersionBitmask = []struct { bitMask uint8 version uint16 }{ - {OfferedSSLVersion20, SSLVersion20}, - {OfferedSSLVersion30, SSLVersion30}, - {OfferedTLSVersion10, TLSVersion10}, - {OfferedTLSVersion11, TLSVersion11}, - {OfferedTLSVersion12, TLSVersion12}, - {OfferedTLSVersion13, TLSVersion13}, + {OfferedTLSVersion10, tls.VersionTLS10}, + {OfferedTLSVersion11, tls.VersionTLS11}, + {OfferedTLSVersion12, tls.VersionTLS12}, + {OfferedTLSVersion13, tls.VersionTLS13}, } // Constants for tag keys diff --git a/pkg/network/protocols/tls/types_test.go b/pkg/network/protocols/tls/types_test.go index 5b5f6258ec88fd..911293f3a32c96 100644 --- a/pkg/network/protocols/tls/types_test.go +++ b/pkg/network/protocols/tls/types_test.go @@ -6,6 +6,7 @@ package tls import ( + "crypto/tls" "fmt" "reflect" "testing" @@ -16,12 +17,10 @@ func TestFormatTLSVersion(t *testing.T) { version uint16 expected string }{ - {SSLVersion20, "SSL 2.0"}, - {SSLVersion30, "SSL 3.0"}, - {TLSVersion10, "TLS 1.0"}, - {TLSVersion11, "TLS 1.1"}, - {TLSVersion12, "TLS 1.2"}, - {TLSVersion13, "TLS 1.3"}, + {tls.VersionTLS10, "TLS 1.0"}, + {tls.VersionTLS11, "TLS 1.1"}, + {tls.VersionTLS12, "TLS 1.2"}, + {tls.VersionTLS13, "TLS 1.3"}, {0xFFFF, ""}, // Unknown version {0x0000, ""}, // Zero value {0x0305, ""}, // Version just above known versions @@ -44,15 +43,13 @@ func TestParseOfferedVersions(t *testing.T) { expected []string }{ {0x00, []string{}}, // No versions offered - {OfferedSSLVersion20, []string{"SSL 2.0"}}, - {OfferedSSLVersion30, []string{"SSL 3.0"}}, {OfferedTLSVersion10, []string{"TLS 1.0"}}, {OfferedTLSVersion11, []string{"TLS 1.1"}}, {OfferedTLSVersion12, []string{"TLS 1.2"}}, {OfferedTLSVersion13, []string{"TLS 1.3"}}, {OfferedTLSVersion10 | OfferedTLSVersion12, []string{"TLS 1.0", "TLS 1.2"}}, - {OfferedSSLVersion30 | OfferedTLSVersion11 | OfferedTLSVersion13, []string{"SSL 3.0", "TLS 1.1", "TLS 1.3"}}, - {0xFF, []string{"SSL 2.0", "SSL 3.0", "TLS 1.0", "TLS 1.1", "TLS 1.2", "TLS 1.3"}}, // All bits set + {OfferedTLSVersion11 | OfferedTLSVersion13, []string{"TLS 1.1", "TLS 1.3"}}, + {0xFF, []string{"TLS 1.0", "TLS 1.1", "TLS 1.2", "TLS 1.3"}}, // All bits set {0x40, []string{}}, // Undefined bit set {0x80, []string{}}, // Undefined bit set } @@ -81,7 +78,7 @@ func TestGetTLSDynamicTags(t *testing.T) { { name: "All_Fields_Populated", tlsTags: &Tags{ - ChosenVersion: TLSVersion12, + ChosenVersion: tls.VersionTLS12, CipherSuite: 0x009C, OfferedVersions: OfferedTLSVersion11 | OfferedTLSVersion12, }, @@ -107,7 +104,7 @@ func TestGetTLSDynamicTags(t *testing.T) { { name: "No_Offered_Versions", tlsTags: &Tags{ - ChosenVersion: TLSVersion13, + ChosenVersion: tls.VersionTLS13, CipherSuite: 0x1301, OfferedVersions: 0x00, }, @@ -119,7 +116,7 @@ func TestGetTLSDynamicTags(t *testing.T) { { name: "Zero_Cipher_Suite", tlsTags: &Tags{ - ChosenVersion: TLSVersion10, + ChosenVersion: tls.VersionTLS10, OfferedVersions: OfferedTLSVersion10, }, expected: map[string]struct{}{ @@ -130,7 +127,7 @@ func TestGetTLSDynamicTags(t *testing.T) { { name: "All_Bits_Set_In_Offered_Versions", tlsTags: &Tags{ - ChosenVersion: TLSVersion12, + ChosenVersion: tls.VersionTLS12, CipherSuite: 0xC02F, OfferedVersions: 0xFF, // All bits set }, @@ -141,8 +138,6 @@ func TestGetTLSDynamicTags(t *testing.T) { "tls.client_version:TLS 1.1": {}, "tls.client_version:TLS 1.2": {}, "tls.client_version:TLS 1.3": {}, - "tls.client_version:SSL 2.0": {}, - "tls.client_version:SSL 3.0": {}, }, }, } diff --git a/pkg/network/tracer/tracer_linux_test.go b/pkg/network/tracer/tracer_linux_test.go index a0e2281dee46f6..8303d2119e66e8 100644 --- a/pkg/network/tracer/tracer_linux_test.go +++ b/pkg/network/tracer/tracer_linux_test.go @@ -2649,7 +2649,6 @@ func validateTLSTags(t *testing.T, tr *Tracer, port uint16, scenario uint16) boo for _, c := range payload.Conns { if c.DPort == port && c.ProtocolStack.Contains(protocols.TLS) && !c.TLSTags.IsEmpty() { tlsTags := ddtls.GetTLSDynamicTags(&c.TLSTags) - t.Log("TLS tags: ", tlsTags) // Check that the cipher suite ID tag is present cipherSuiteTagFound := false From c5a48a25bf6aa37f6b0c0cd6b8fc12c45cb2c83f Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Tue, 3 Dec 2024 13:26:39 -0500 Subject: [PATCH 32/53] stab at disabling RC --- pkg/network/ebpf/c/protocols/tls/tls.h | 4 +- .../tracer/connection/kprobe/config.go | 6 +- .../tracer/connection/kprobe/tracer.go | 113 ++++++++++-------- pkg/network/tracer/tracer_linux_test.go | 6 +- 4 files changed, 71 insertions(+), 58 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index 1a448a7abe63b0..9d251b12ecab2c 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -16,8 +16,8 @@ #define TLS_APPLICATION_DATA 0x17 #define SUPPORTED_VERSIONS_EXTENSION 0x002B -#define CLIENT_MAX_EXTENSIONS 4 -#define SERVER_MAX_EXTENSIONS 6 +#define CLIENT_MAX_EXTENSIONS 8 +#define SERVER_MAX_EXTENSIONS 8 /* https://www.rfc-editor.org/rfc/rfc5246#page-19 6.2. Record Layer */ diff --git a/pkg/network/tracer/connection/kprobe/config.go b/pkg/network/tracer/connection/kprobe/config.go index b7c6bb2ff99865..b961d14a58f4ed 100644 --- a/pkg/network/tracer/connection/kprobe/config.go +++ b/pkg/network/tracer/connection/kprobe/config.go @@ -58,13 +58,15 @@ func enabledProbes(c *config.Config, runtimeTracer, coreTracer bool) (map[probes if c.CollectTCPv4Conns || c.CollectTCPv6Conns { if ClassificationSupported(c) { enableProbe(enabled, probes.ProtocolClassifierEntrySocketFilter) - enableProbe(enabled, probes.ProtocolClassifierTLSClientSocketFilter) - enableProbe(enabled, probes.ProtocolClassifierTLSServerSocketFilter) enableProbe(enabled, probes.ProtocolClassifierQueuesSocketFilter) enableProbe(enabled, probes.ProtocolClassifierDBsSocketFilter) enableProbe(enabled, probes.ProtocolClassifierGRPCSocketFilter) enableProbe(enabled, probes.NetDevQueue) enableProbe(enabled, probes.TCPCloseCleanProtocolsReturn) + if !runtimeTracer { + enableProbe(enabled, probes.ProtocolClassifierTLSClientSocketFilter) + enableProbe(enabled, probes.ProtocolClassifierTLSServerSocketFilter) + } } enableProbe(enabled, selectVersionBasedProbe(runtimeTracer, kv, probes.TCPSendMsg, probes.TCPSendMsgPre410, kv410)) enableProbe(enabled, probes.TCPSendMsgReturn) diff --git a/pkg/network/tracer/connection/kprobe/tracer.go b/pkg/network/tracer/connection/kprobe/tracer.go index 56f282b5e20346..469b2ed7a0cfb4 100644 --- a/pkg/network/tracer/connection/kprobe/tracer.go +++ b/pkg/network/tracer/connection/kprobe/tracer.go @@ -44,57 +44,6 @@ var ( // - 2492d3b867043f6880708d095a7a5d65debcfc32 classificationMinimumKernel = kernel.VersionCode(4, 11, 0) - protocolClassificationTailCalls = []manager.TailCallRoute{ - { - ProgArrayName: probes.ClassificationProgsMap, - Key: netebpf.ClassificationTLSClient, - ProbeIdentificationPair: manager.ProbeIdentificationPair{ - EBPFFuncName: probes.ProtocolClassifierTLSClientSocketFilter, - UID: probeUID, - }, - }, - { - ProgArrayName: probes.ClassificationProgsMap, - Key: netebpf.ClassificationTLSServer, - ProbeIdentificationPair: manager.ProbeIdentificationPair{ - EBPFFuncName: probes.ProtocolClassifierTLSServerSocketFilter, - UID: probeUID, - }, - }, - { - ProgArrayName: probes.ClassificationProgsMap, - Key: netebpf.ClassificationQueues, - ProbeIdentificationPair: manager.ProbeIdentificationPair{ - EBPFFuncName: probes.ProtocolClassifierQueuesSocketFilter, - UID: probeUID, - }, - }, - { - ProgArrayName: probes.ClassificationProgsMap, - Key: netebpf.ClassificationDBs, - ProbeIdentificationPair: manager.ProbeIdentificationPair{ - EBPFFuncName: probes.ProtocolClassifierDBsSocketFilter, - UID: probeUID, - }, - }, - { - ProgArrayName: probes.ClassificationProgsMap, - Key: netebpf.ClassificationGRPC, - ProbeIdentificationPair: manager.ProbeIdentificationPair{ - EBPFFuncName: probes.ProtocolClassifierGRPCSocketFilter, - UID: probeUID, - }, - }, - { - ProgArrayName: probes.TCPCloseProgsMap, - Key: 0, - ProbeIdentificationPair: manager.ProbeIdentificationPair{ - EBPFFuncName: probes.TCPCloseFlushReturn, - UID: probeUID, - }, - }, - } - // these primarily exist for mocking out in tests coreTracerLoader = loadCORETracer rcTracerLoader = loadRuntimeCompiledTracer @@ -209,6 +158,7 @@ func loadTracerFromAsset(buf bytecode.AssetReader, runtimeTracer, coreTracer boo var tailCallsIdentifiersSet map[manager.ProbeIdentificationPair]struct{} if classificationSupported { + protocolClassificationTailCalls := getProtocolClassificationTailCalls(runtimeTracer) tailCallsIdentifiersSet = make(map[manager.ProbeIdentificationPair]struct{}, len(protocolClassificationTailCalls)) for _, tailCall := range protocolClassificationTailCalls { tailCallsIdentifiersSet[tailCall.ProbeIdentificationPair] = struct{}{} @@ -356,3 +306,64 @@ func isCORETracerSupported() error { return errCORETracerNotSupported } + +func getProtocolClassificationTailCalls(runtimeTracer bool) []manager.TailCallRoute { + protocolClassificationTailCalls := []manager.TailCallRoute{} + + if !runtimeTracer { + protocolClassificationTailCalls = append(protocolClassificationTailCalls, + manager.TailCallRoute{ + ProgArrayName: probes.ClassificationProgsMap, + Key: netebpf.ClassificationTLSClient, + ProbeIdentificationPair: manager.ProbeIdentificationPair{ + EBPFFuncName: probes.ProtocolClassifierTLSClientSocketFilter, + UID: probeUID, + }, + }, + manager.TailCallRoute{ + ProgArrayName: probes.ClassificationProgsMap, + Key: netebpf.ClassificationTLSServer, + ProbeIdentificationPair: manager.ProbeIdentificationPair{ + EBPFFuncName: probes.ProtocolClassifierTLSServerSocketFilter, + UID: probeUID, + }, + }, + ) + } + + protocolClassificationTailCalls = append(protocolClassificationTailCalls, + manager.TailCallRoute{ + ProgArrayName: probes.ClassificationProgsMap, + Key: netebpf.ClassificationQueues, + ProbeIdentificationPair: manager.ProbeIdentificationPair{ + EBPFFuncName: probes.ProtocolClassifierQueuesSocketFilter, + UID: probeUID, + }, + }, + manager.TailCallRoute{ + ProgArrayName: probes.ClassificationProgsMap, + Key: netebpf.ClassificationDBs, + ProbeIdentificationPair: manager.ProbeIdentificationPair{ + EBPFFuncName: probes.ProtocolClassifierDBsSocketFilter, + UID: probeUID, + }, + }, + manager.TailCallRoute{ + ProgArrayName: probes.ClassificationProgsMap, + Key: netebpf.ClassificationGRPC, + ProbeIdentificationPair: manager.ProbeIdentificationPair{ + EBPFFuncName: probes.ProtocolClassifierGRPCSocketFilter, + UID: probeUID, + }, + }, + manager.TailCallRoute{ + ProgArrayName: probes.TCPCloseProgsMap, + Key: 0, + ProbeIdentificationPair: manager.ProbeIdentificationPair{ + EBPFFuncName: probes.TCPCloseFlushReturn, + UID: probeUID, + }, + }, + ) + return protocolClassificationTailCalls +} diff --git a/pkg/network/tracer/tracer_linux_test.go b/pkg/network/tracer/tracer_linux_test.go index 8303d2119e66e8..b5c7fa8edf1738 100644 --- a/pkg/network/tracer/tracer_linux_test.go +++ b/pkg/network/tracer/tracer_linux_test.go @@ -2565,13 +2565,15 @@ func setupDropTrafficRule(tb testing.TB) (ns string) { func (s *TracerSuite) TestTLSClassification() { t := s.T() + if ebpftest.GetBuildMode() == ebpftest.RuntimeCompiled { + t.Skip("Skipping test on unsupported build mode: ", ebpftest.GetBuildMode()) + } cfg := testConfig() if !kprobe.ClassificationSupported(cfg) { t.Skip("TLS classification platform not supported") } port, err := tracertestutil.GetFreePort() - //port := uint16(44957) require.NoError(t, err) portAsString := strconv.Itoa(int(port)) @@ -2584,7 +2586,6 @@ func (s *TracerSuite) TestTLSClassification() { } tests := make([]tlsTest, 0) for _, scenario := range []uint16{tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13} { - //for _, scenario := range []uint16{tls.VersionTLS12} { scenario := scenario tests = append(tests, tlsTest{ name: strings.Replace(tls.VersionName(scenario), " ", "-", 1), @@ -2677,7 +2678,6 @@ func validateTLSTags(t *testing.T, tr *Tracer, port uint16, scenario uint16) boo return false } - // Optionally, check for multiple offered versions (e.g., for TLS 1.3) if scenario == tls.VersionTLS13 { expectedClientVersions := []string{ ddtls.TagTLSClientVersion + tls.VersionName(tls.VersionTLS12), From 064cc825d09e27f0930dfd2dd915da9dbda2fac1 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Tue, 3 Dec 2024 15:53:47 -0500 Subject: [PATCH 33/53] Revert "stab at disabling RC" This reverts commit c5a48a25bf6aa37f6b0c0cd6b8fc12c45cb2c83f. --- pkg/network/ebpf/c/protocols/tls/tls.h | 4 +- .../tracer/connection/kprobe/config.go | 6 +- .../tracer/connection/kprobe/tracer.go | 113 ++++++++---------- pkg/network/tracer/tracer_linux_test.go | 6 +- 4 files changed, 58 insertions(+), 71 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index 9d251b12ecab2c..1a448a7abe63b0 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -16,8 +16,8 @@ #define TLS_APPLICATION_DATA 0x17 #define SUPPORTED_VERSIONS_EXTENSION 0x002B -#define CLIENT_MAX_EXTENSIONS 8 -#define SERVER_MAX_EXTENSIONS 8 +#define CLIENT_MAX_EXTENSIONS 4 +#define SERVER_MAX_EXTENSIONS 6 /* https://www.rfc-editor.org/rfc/rfc5246#page-19 6.2. Record Layer */ diff --git a/pkg/network/tracer/connection/kprobe/config.go b/pkg/network/tracer/connection/kprobe/config.go index b961d14a58f4ed..b7c6bb2ff99865 100644 --- a/pkg/network/tracer/connection/kprobe/config.go +++ b/pkg/network/tracer/connection/kprobe/config.go @@ -58,15 +58,13 @@ func enabledProbes(c *config.Config, runtimeTracer, coreTracer bool) (map[probes if c.CollectTCPv4Conns || c.CollectTCPv6Conns { if ClassificationSupported(c) { enableProbe(enabled, probes.ProtocolClassifierEntrySocketFilter) + enableProbe(enabled, probes.ProtocolClassifierTLSClientSocketFilter) + enableProbe(enabled, probes.ProtocolClassifierTLSServerSocketFilter) enableProbe(enabled, probes.ProtocolClassifierQueuesSocketFilter) enableProbe(enabled, probes.ProtocolClassifierDBsSocketFilter) enableProbe(enabled, probes.ProtocolClassifierGRPCSocketFilter) enableProbe(enabled, probes.NetDevQueue) enableProbe(enabled, probes.TCPCloseCleanProtocolsReturn) - if !runtimeTracer { - enableProbe(enabled, probes.ProtocolClassifierTLSClientSocketFilter) - enableProbe(enabled, probes.ProtocolClassifierTLSServerSocketFilter) - } } enableProbe(enabled, selectVersionBasedProbe(runtimeTracer, kv, probes.TCPSendMsg, probes.TCPSendMsgPre410, kv410)) enableProbe(enabled, probes.TCPSendMsgReturn) diff --git a/pkg/network/tracer/connection/kprobe/tracer.go b/pkg/network/tracer/connection/kprobe/tracer.go index 469b2ed7a0cfb4..56f282b5e20346 100644 --- a/pkg/network/tracer/connection/kprobe/tracer.go +++ b/pkg/network/tracer/connection/kprobe/tracer.go @@ -44,6 +44,57 @@ var ( // - 2492d3b867043f6880708d095a7a5d65debcfc32 classificationMinimumKernel = kernel.VersionCode(4, 11, 0) + protocolClassificationTailCalls = []manager.TailCallRoute{ + { + ProgArrayName: probes.ClassificationProgsMap, + Key: netebpf.ClassificationTLSClient, + ProbeIdentificationPair: manager.ProbeIdentificationPair{ + EBPFFuncName: probes.ProtocolClassifierTLSClientSocketFilter, + UID: probeUID, + }, + }, + { + ProgArrayName: probes.ClassificationProgsMap, + Key: netebpf.ClassificationTLSServer, + ProbeIdentificationPair: manager.ProbeIdentificationPair{ + EBPFFuncName: probes.ProtocolClassifierTLSServerSocketFilter, + UID: probeUID, + }, + }, + { + ProgArrayName: probes.ClassificationProgsMap, + Key: netebpf.ClassificationQueues, + ProbeIdentificationPair: manager.ProbeIdentificationPair{ + EBPFFuncName: probes.ProtocolClassifierQueuesSocketFilter, + UID: probeUID, + }, + }, + { + ProgArrayName: probes.ClassificationProgsMap, + Key: netebpf.ClassificationDBs, + ProbeIdentificationPair: manager.ProbeIdentificationPair{ + EBPFFuncName: probes.ProtocolClassifierDBsSocketFilter, + UID: probeUID, + }, + }, + { + ProgArrayName: probes.ClassificationProgsMap, + Key: netebpf.ClassificationGRPC, + ProbeIdentificationPair: manager.ProbeIdentificationPair{ + EBPFFuncName: probes.ProtocolClassifierGRPCSocketFilter, + UID: probeUID, + }, + }, + { + ProgArrayName: probes.TCPCloseProgsMap, + Key: 0, + ProbeIdentificationPair: manager.ProbeIdentificationPair{ + EBPFFuncName: probes.TCPCloseFlushReturn, + UID: probeUID, + }, + }, + } + // these primarily exist for mocking out in tests coreTracerLoader = loadCORETracer rcTracerLoader = loadRuntimeCompiledTracer @@ -158,7 +209,6 @@ func loadTracerFromAsset(buf bytecode.AssetReader, runtimeTracer, coreTracer boo var tailCallsIdentifiersSet map[manager.ProbeIdentificationPair]struct{} if classificationSupported { - protocolClassificationTailCalls := getProtocolClassificationTailCalls(runtimeTracer) tailCallsIdentifiersSet = make(map[manager.ProbeIdentificationPair]struct{}, len(protocolClassificationTailCalls)) for _, tailCall := range protocolClassificationTailCalls { tailCallsIdentifiersSet[tailCall.ProbeIdentificationPair] = struct{}{} @@ -306,64 +356,3 @@ func isCORETracerSupported() error { return errCORETracerNotSupported } - -func getProtocolClassificationTailCalls(runtimeTracer bool) []manager.TailCallRoute { - protocolClassificationTailCalls := []manager.TailCallRoute{} - - if !runtimeTracer { - protocolClassificationTailCalls = append(protocolClassificationTailCalls, - manager.TailCallRoute{ - ProgArrayName: probes.ClassificationProgsMap, - Key: netebpf.ClassificationTLSClient, - ProbeIdentificationPair: manager.ProbeIdentificationPair{ - EBPFFuncName: probes.ProtocolClassifierTLSClientSocketFilter, - UID: probeUID, - }, - }, - manager.TailCallRoute{ - ProgArrayName: probes.ClassificationProgsMap, - Key: netebpf.ClassificationTLSServer, - ProbeIdentificationPair: manager.ProbeIdentificationPair{ - EBPFFuncName: probes.ProtocolClassifierTLSServerSocketFilter, - UID: probeUID, - }, - }, - ) - } - - protocolClassificationTailCalls = append(protocolClassificationTailCalls, - manager.TailCallRoute{ - ProgArrayName: probes.ClassificationProgsMap, - Key: netebpf.ClassificationQueues, - ProbeIdentificationPair: manager.ProbeIdentificationPair{ - EBPFFuncName: probes.ProtocolClassifierQueuesSocketFilter, - UID: probeUID, - }, - }, - manager.TailCallRoute{ - ProgArrayName: probes.ClassificationProgsMap, - Key: netebpf.ClassificationDBs, - ProbeIdentificationPair: manager.ProbeIdentificationPair{ - EBPFFuncName: probes.ProtocolClassifierDBsSocketFilter, - UID: probeUID, - }, - }, - manager.TailCallRoute{ - ProgArrayName: probes.ClassificationProgsMap, - Key: netebpf.ClassificationGRPC, - ProbeIdentificationPair: manager.ProbeIdentificationPair{ - EBPFFuncName: probes.ProtocolClassifierGRPCSocketFilter, - UID: probeUID, - }, - }, - manager.TailCallRoute{ - ProgArrayName: probes.TCPCloseProgsMap, - Key: 0, - ProbeIdentificationPair: manager.ProbeIdentificationPair{ - EBPFFuncName: probes.TCPCloseFlushReturn, - UID: probeUID, - }, - }, - ) - return protocolClassificationTailCalls -} diff --git a/pkg/network/tracer/tracer_linux_test.go b/pkg/network/tracer/tracer_linux_test.go index b5c7fa8edf1738..8303d2119e66e8 100644 --- a/pkg/network/tracer/tracer_linux_test.go +++ b/pkg/network/tracer/tracer_linux_test.go @@ -2565,15 +2565,13 @@ func setupDropTrafficRule(tb testing.TB) (ns string) { func (s *TracerSuite) TestTLSClassification() { t := s.T() - if ebpftest.GetBuildMode() == ebpftest.RuntimeCompiled { - t.Skip("Skipping test on unsupported build mode: ", ebpftest.GetBuildMode()) - } cfg := testConfig() if !kprobe.ClassificationSupported(cfg) { t.Skip("TLS classification platform not supported") } port, err := tracertestutil.GetFreePort() + //port := uint16(44957) require.NoError(t, err) portAsString := strconv.Itoa(int(port)) @@ -2586,6 +2584,7 @@ func (s *TracerSuite) TestTLSClassification() { } tests := make([]tlsTest, 0) for _, scenario := range []uint16{tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13} { + //for _, scenario := range []uint16{tls.VersionTLS12} { scenario := scenario tests = append(tests, tlsTest{ name: strings.Replace(tls.VersionName(scenario), " ", "-", 1), @@ -2678,6 +2677,7 @@ func validateTLSTags(t *testing.T, tr *Tracer, port uint16, scenario uint16) boo return false } + // Optionally, check for multiple offered versions (e.g., for TLS 1.3) if scenario == tls.VersionTLS13 { expectedClientVersions := []string{ ddtls.TagTLSClientVersion + tls.VersionName(tls.VersionTLS12), From 3d15f660a18a9d33deb5aa1589d650c21353f517 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Tue, 3 Dec 2024 15:54:42 -0500 Subject: [PATCH 34/53] set max extensions to 1 --- pkg/network/ebpf/c/protocols/tls/tls.h | 4 ++-- pkg/network/tracer/tracer_linux_test.go | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index 1a448a7abe63b0..de9fae3f52d933 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -16,8 +16,8 @@ #define TLS_APPLICATION_DATA 0x17 #define SUPPORTED_VERSIONS_EXTENSION 0x002B -#define CLIENT_MAX_EXTENSIONS 4 -#define SERVER_MAX_EXTENSIONS 6 +#define CLIENT_MAX_EXTENSIONS 1 +#define SERVER_MAX_EXTENSIONS 1 /* https://www.rfc-editor.org/rfc/rfc5246#page-19 6.2. Record Layer */ diff --git a/pkg/network/tracer/tracer_linux_test.go b/pkg/network/tracer/tracer_linux_test.go index 8303d2119e66e8..cebc8f4b573ba8 100644 --- a/pkg/network/tracer/tracer_linux_test.go +++ b/pkg/network/tracer/tracer_linux_test.go @@ -2565,13 +2565,13 @@ func setupDropTrafficRule(tb testing.TB) (ns string) { func (s *TracerSuite) TestTLSClassification() { t := s.T() + t.Skip() cfg := testConfig() if !kprobe.ClassificationSupported(cfg) { t.Skip("TLS classification platform not supported") } port, err := tracertestutil.GetFreePort() - //port := uint16(44957) require.NoError(t, err) portAsString := strconv.Itoa(int(port)) @@ -2584,7 +2584,6 @@ func (s *TracerSuite) TestTLSClassification() { } tests := make([]tlsTest, 0) for _, scenario := range []uint16{tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13} { - //for _, scenario := range []uint16{tls.VersionTLS12} { scenario := scenario tests = append(tests, tlsTest{ name: strings.Replace(tls.VersionName(scenario), " ", "-", 1), @@ -2677,7 +2676,6 @@ func validateTLSTags(t *testing.T, tr *Tracer, port uint16, scenario uint16) boo return false } - // Optionally, check for multiple offered versions (e.g., for TLS 1.3) if scenario == tls.VersionTLS13 { expectedClientVersions := []string{ ddtls.TagTLSClientVersion + tls.VersionName(tls.VersionTLS12), From 8e98fa23d45472be52fdd7f31ad1f2f455f22f8e Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Tue, 3 Dec 2024 17:16:33 -0500 Subject: [PATCH 35/53] increase extension count --- pkg/network/ebpf/c/protocols/tls/tls.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index de9fae3f52d933..b7b0d183d26402 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -16,8 +16,8 @@ #define TLS_APPLICATION_DATA 0x17 #define SUPPORTED_VERSIONS_EXTENSION 0x002B -#define CLIENT_MAX_EXTENSIONS 1 -#define SERVER_MAX_EXTENSIONS 1 +#define CLIENT_MAX_EXTENSIONS 16 +#define SERVER_MAX_EXTENSIONS 16 /* https://www.rfc-editor.org/rfc/rfc5246#page-19 6.2. Record Layer */ From 52a9e5fd7cdbb33bc3f0f3a96694fcdc14bdc335 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Wed, 4 Dec 2024 13:57:24 -0500 Subject: [PATCH 36/53] refactor tls packet parsing --- .../classification/shared-tracer-maps.h | 2 +- pkg/network/ebpf/c/protocols/tls/tls.h | 452 +++++++++--------- pkg/network/ebpf/c/tracer/events.h | 2 +- pkg/network/ebpf/c/tracer/tracer.h | 1 - pkg/network/ebpf/kprobe_types_linux.go | 2 +- 5 files changed, 225 insertions(+), 234 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h b/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h index 50b8162ddc9f73..e673aedb8c9982 100644 --- a/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h +++ b/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h @@ -159,7 +159,7 @@ static __always_inline tls_info_t* get_or_create_tls_enhanced_tags(conn_tuple_t if (!tags) { conn_tuple_t normalized_tup = *tuple; normalize_tuple(&normalized_tup); - tls_info_t empty_tags = {.reserved = 1}; + tls_info_t empty_tags = {0}; bpf_map_update_elem(&tls_enhanced_tags, &normalized_tup, &empty_tags, BPF_ANY); tags = bpf_map_lookup_elem(&tls_enhanced_tags, &normalized_tup); } diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index b7b0d183d26402..8281f53421b1c6 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -1,8 +1,6 @@ #ifndef __TLS_H #define __TLS_H -#include "ktypes.h" -#include "bpf_builtins.h" #include "tracer/tracer.h" #define SSL_VERSION20 0x0200 @@ -12,27 +10,27 @@ #define TLS_VERSION12 0x0303 #define TLS_VERSION13 0x0304 +// TLS Content Types as per RFC 5246 Section 6.2.1 #define TLS_HANDSHAKE 0x16 #define TLS_APPLICATION_DATA 0x17 -#define SUPPORTED_VERSIONS_EXTENSION 0x002B -#define CLIENT_MAX_EXTENSIONS 16 -#define SERVER_MAX_EXTENSIONS 16 +#define TLS_HANDSHAKE_CLIENT_HELLO 0x01 +#define TLS_HANDSHAKE_SERVER_HELLO 0x02 -/* https://www.rfc-editor.org/rfc/rfc5246#page-19 6.2. Record Layer */ +// TLS extensions to parse from the Hello message when searching for the SUPPORTED_VERSIONS_EXTENSION +#define MAX_EXTENSIONS 16 +#define SUPPORTED_VERSIONS_EXTENSION 0x002B +// this corresponds to 16 KB, which is the maximum TLS record size as per the specification #define TLS_MAX_PAYLOAD_LENGTH (1 << 14) -// TLS record layer header structure +// TLS record layer header structure (https://www.rfc-editor.org/rfc/rfc5246#page-19) typedef struct { __u8 content_type; __u16 version; __u16 length; } __attribute__((packed)) tls_record_header_t; -#define TLS_HANDSHAKE_CLIENT_HELLO 0x01 -#define TLS_HANDSHAKE_SERVER_HELLO 0x02 - // is_valid_tls_version checks if the version is a valid TLS version static __always_inline bool is_valid_tls_version(__u16 version) { switch (version) { @@ -110,55 +108,186 @@ static __always_inline bool is_tls(struct __sk_buff *skb, __u64 nh_off, tls_reco return true; } -// parse_client_hello reads the ClientHello message from the TLS handshake and populates select tags -static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offset, __u32 skb_len, tls_info_t *tags) { +static __always_inline int parse_tls_handshake_header(struct __sk_buff *skb, __u64 *offset, __u32 skb_len, __u32 *handshake_length, __u16 *protocol_version) { // Move offset past handshake type (1 byte) - offset += 1; + *offset += 1; // Read handshake length (3 bytes) + if (*offset + 3 > skb_len) + return -1; __u8 handshake_length_bytes[3]; - if (bpf_skb_load_bytes(skb, offset, handshake_length_bytes, 3) < 0) + if (bpf_skb_load_bytes(skb, *offset, handshake_length_bytes, 3) < 0) return -1; - __u32 handshake_length = (handshake_length_bytes[0] << 16) | - (handshake_length_bytes[1] << 8) | - handshake_length_bytes[2]; - offset += 3; + *handshake_length = (handshake_length_bytes[0] << 16) | + (handshake_length_bytes[1] << 8) | + handshake_length_bytes[2]; + *offset += 3; // Ensure we don't read beyond the packet - if (offset + handshake_length > skb_len) + if (*offset + *handshake_length > skb_len) return -1; - // Read client_version (2 bytes) - __u16 client_version; - if (bpf_skb_load_bytes(skb, offset, &client_version, sizeof(client_version)) < 0) + // Read protocol version (2 bytes) + if (*offset + 2 > skb_len) return -1; - client_version = bpf_ntohs(client_version); - offset += 2; - - // Store client_version in tags (in case supported_versions extension is absent) - set_tls_offered_version(tags, client_version); - - if (client_version != TLS_VERSION12) { - // if the version is less than 1.2, there won't be any extensions and we can stop here - return 0; - } + __u16 version; + if (bpf_skb_load_bytes(skb, *offset, &version, sizeof(version)) < 0) + return -1; + *protocol_version = bpf_ntohs(version); + *offset += 2; - // Check if there are extensions if the version is listed as TLS 1.2, as this - // version may actually be 1.3 and the real version is in the extensions + return 0; +} +static __always_inline int skip_random_and_session_id(struct __sk_buff *skb, __u64 *offset, __u32 skb_len) { // Skip Random (32 bytes) - offset += 32; + *offset += 32; // Read Session ID Length (1 byte) + if (*offset + 1 > skb_len) + return -1; __u8 session_id_length; - if (bpf_skb_load_bytes(skb, offset, &session_id_length, sizeof(session_id_length)) < 0) + if (bpf_skb_load_bytes(skb, *offset, &session_id_length, sizeof(session_id_length)) < 0) return -1; - offset += 1; + *offset += 1; // Skip Session ID - offset += session_id_length; + *offset += session_id_length; + + // Ensure we don't read beyond the packet + if (*offset > skb_len) + return -1; + + return 0; +} + +static __always_inline int parse_supported_versions_extension(struct __sk_buff *skb, __u64 *offset, __u64 extensions_end, __u32 skb_len, tls_info_t *tags, bool is_client_hello) { + if (is_client_hello) { + // Read list length (1 byte) + if (*offset + 1 > skb_len || *offset + 1 > extensions_end) + return -1; + __u8 sv_list_length; + if (bpf_skb_load_bytes(skb, *offset, &sv_list_length, sizeof(sv_list_length)) < 0) + return -1; + *offset += 1; + + // Ensure we don't read beyond the packet + if (*offset + sv_list_length > skb_len || *offset + sv_list_length > extensions_end) + return -1; + + // Parse the list of supported versions + __u8 sv_offset = 0; + __u16 sv_version; + #define MAX_SUPPORTED_VERSIONS 4 + #pragma unroll(MAX_SUPPORTED_VERSIONS) + for (int idx = 0; idx < MAX_SUPPORTED_VERSIONS; idx++) { + if (sv_offset + 1 >= sv_list_length) + break; + if (*offset + 2 > skb_len) + return -1; + + // Load the supported version + if (bpf_skb_load_bytes(skb, *offset, &sv_version, sizeof(sv_version)) < 0) + return -1; + sv_version = bpf_ntohs(sv_version); + *offset += 2; + + // Store the version + set_tls_offered_version(tags, sv_version); + + sv_offset += 2; + } + } else { + // ServerHello + // Extension Length should be 2 + if (*offset + 2 > skb_len) + return -1; + + // Read selected version (2 bytes) + __u16 selected_version; + if (bpf_skb_load_bytes(skb, *offset, &selected_version, sizeof(selected_version)) < 0) + return -1; + selected_version = bpf_ntohs(selected_version); + *offset += 2; + + tags->chosen_version = selected_version; + } + + return 0; +} + + +static __always_inline int parse_tls_extensions(struct __sk_buff *skb, __u64 *offset, __u64 extensions_end, __u32 skb_len, tls_info_t *tags, bool is_client_hello) { + __u16 extension_type; + __u16 extension_length; + + #pragma unroll(MAX_EXTENSIONS) + for (int i = 0; i < MAX_EXTENSIONS; i++) { + if (*offset + 4 > extensions_end) { + break; + } + // Read Extension Type (2 bytes) + if (bpf_skb_load_bytes(skb, *offset, &extension_type, sizeof(extension_type)) < 0) + return -1; + extension_type = bpf_ntohs(extension_type); + *offset += 2; + + // Read Extension Length (2 bytes) + if (bpf_skb_load_bytes(skb, *offset, &extension_length, sizeof(extension_length)) < 0) + return -1; + extension_length = bpf_ntohs(extension_length); + *offset += 2; + + // Ensure we don't read beyond the packet + if (*offset + extension_length > skb_len || *offset + extension_length > extensions_end) + return -1; + + // Check for supported_versions extension + if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { + int res = parse_supported_versions_extension(skb, offset, extensions_end, skb_len, tags, is_client_hello); + if (res != 0) + return res; + } else { + // Skip other extensions + *offset += extension_length; + } + + // Ensure we don't run past the extensions_end + if (*offset >= extensions_end) + break; + } + + return 0; +} + + +// parse_client_hello reads the ClientHello message from the TLS handshake and populates select tags +static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offset, __u32 skb_len, tls_info_t *tags) { + __u32 handshake_length; + __u16 client_version; + int res; + + // Parse the handshake header + res = parse_tls_handshake_header(skb, &offset, skb_len, &handshake_length, &client_version); + if (res != 0) + return res; + + // Store client_version in tags (in case supported_versions extension is absent) + set_tls_offered_version(tags, client_version); + + if (client_version != TLS_VERSION12) { + // If the version is less than 1.2, there won't be any extensions + return 0; + } + + // Skip Random and Session ID + res = skip_random_and_session_id(skb, &offset, skb_len); + if (res != 0) + return res; // Read Cipher Suites Length (2 bytes) + if (offset + 2 > skb_len) + return -1; __u16 cipher_suites_length; if (bpf_skb_load_bytes(skb, offset, &cipher_suites_length, sizeof(cipher_suites_length)) < 0) return -1; @@ -169,6 +298,8 @@ static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offse offset += cipher_suites_length; // Read Compression Methods Length (1 byte) + if (offset + 1 > skb_len) + return -1; __u8 compression_methods_length; if (bpf_skb_load_bytes(skb, offset, &compression_methods_length, sizeof(compression_methods_length)) < 0) return -1; @@ -192,128 +323,41 @@ static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offse if (offset + extensions_length > skb_len) return -1; - // Parse Extensions __u64 extensions_end = offset + extensions_length; - __u8 extensions_parsed = 0; - __u16 extension_type; - __u16 extension_length; - __u8 sv_list_length; - - #pragma unroll(CLIENT_MAX_EXTENSIONS) - for (int i = 0; i < CLIENT_MAX_EXTENSIONS; i++) { - if (offset + 4 > extensions_end) { - break; - } - // Read Extension Type (2 bytes) - if (bpf_skb_load_bytes(skb, offset, &extension_type, sizeof(extension_type)) < 0) - return -1; - extension_type = bpf_ntohs(extension_type); - offset += 2; - - // Read Extension Length (2 bytes) - if (bpf_skb_load_bytes(skb, offset, &extension_length, sizeof(extension_length)) < 0) - return -1; - extension_length = bpf_ntohs(extension_length); - offset += 2; - - // Ensure we don't read beyond the packet - if (offset + extension_length > skb_len || offset + extension_length > extensions_end) - return -1; - // Check for supported_versions extension (type 43 or 0x002B) - if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { - // Parse supported_versions extension - if (offset + 1 > skb_len) - return -1; - - // Read list length (1 byte) - if (bpf_skb_load_bytes(skb, offset, &sv_list_length, sizeof(sv_list_length)) < 0) - return -1; - offset += 1; - - // Ensure we don't read beyond the packet - if (offset + sv_list_length > skb_len || offset + sv_list_length > extensions_end) - return -1; - - #define MAX_SUPPORTED_VERSIONS 4 - __u8 num_versions = 0; - __u8 i = 0; - __u16 sv_version; - - #pragma unroll(MAX_SUPPORTED_VERSIONS) - for (int idx = 0; idx < MAX_SUPPORTED_VERSIONS; idx++) { - if (i + 1 >= sv_list_length) - break; - if (offset + 2 > skb_len) - return -1; - - // Load the supported version - if (bpf_skb_load_bytes(skb, offset, &sv_version, sizeof(sv_version)) < 0) - return -1; - sv_version = bpf_ntohs(sv_version); - offset += 2; - - // Store the version - set_tls_offered_version(tags, sv_version); - - num_versions++; - i += 2; - } - } else { - // Skip other extensions - offset += extension_length; - } - - extensions_parsed++; - } + // Parse Extensions + res = parse_tls_extensions(skb, &offset, extensions_end, skb_len, tags, true); + if (res != 0) + return res; return 0; } + // parse_server_hello reads the ServerHello message from the TLS handshake and populates select tags static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offset, __u32 skb_len, tls_info_t *tags) { - // Move offset past handshake type (1 byte) - offset += 1; - - // Read handshake length (3 bytes) - __u8 handshake_length_bytes[3]; - if (bpf_skb_load_bytes(skb, offset, handshake_length_bytes, 3) < 0) - return -1; - __u32 handshake_length = (handshake_length_bytes[0] << 16) | - (handshake_length_bytes[1] << 8) | - (handshake_length_bytes[2]); - offset += 3; - - // Ensure we don't read beyond the packet - if (offset + handshake_length > skb_len) - return -1; + __u32 handshake_length; + __u16 server_version; + int res; - __u64 handshake_end = offset + handshake_length; + // Parse the handshake header + res = parse_tls_handshake_header(skb, &offset, skb_len, &handshake_length, &server_version); + if (res != 0) + return res; - // Read server_version (2 bytes) - __u16 server_version; - if (bpf_skb_load_bytes(skb, offset, &server_version, sizeof(server_version)) < 0) - return -1; - server_version = bpf_ntohs(server_version); // Set the version here and try to get the "real" version from the extensions // Note: In TLS 1.3, the server_version field is set to 0x0303 (TLS 1.2) // The actual version is embedded in the supported_versions extension tags->chosen_version = server_version; - offset += 2; - - // Skip Random (32 bytes) - offset += 32; - - // Read Session ID Length (1 byte) - __u8 session_id_length; - if (bpf_skb_load_bytes(skb, offset, &session_id_length, sizeof(session_id_length)) < 0) - return -1; - offset += 1; - // Skip Session ID - offset += session_id_length; + // Skip Random and Session ID + res = skip_random_and_session_id(skb, &offset, skb_len); + if (res != 0) + return res; // Read Cipher Suite (2 bytes) + if (offset + 2 > skb_len) + return -1; __u16 cipher_suite; if (bpf_skb_load_bytes(skb, offset, &cipher_suite, sizeof(cipher_suite)) < 0) return -1; @@ -327,111 +371,59 @@ static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offse tags->cipher_suite = cipher_suite; if (tags->chosen_version != TLS_VERSION12) { - // if the version is less than 1.2, there won't be any extensions and we can stop here + // If the version is less than 1.2, there won't be any extensions return 0; } - // Check if there are extensions if the version is listed as TLS 1.2, as this - // version may actually be 1.3 and the real version is in the extensions - if (offset < handshake_end) { - // Read Extensions Length (2 bytes) - if (offset + 2 > skb_len || offset + 2 > handshake_end) - return -1; - __u16 extensions_length; - if (bpf_skb_load_bytes(skb, offset, &extensions_length, sizeof(extensions_length)) < 0) - return -1; - extensions_length = bpf_ntohs(extensions_length); - offset += 2; + // Check if there are extensions + if (offset + 2 > skb_len) + return -1; - // Ensure we don't read beyond the packet - if (offset + extensions_length > skb_len || offset + extensions_length > handshake_end) - return -1; + // Read Extensions Length (2 bytes) + __u16 extensions_length; + if (bpf_skb_load_bytes(skb, offset, &extensions_length, sizeof(extensions_length)) < 0) + return -1; + extensions_length = bpf_ntohs(extensions_length); + offset += 2; - // Parse Extensions - __u64 extensions_end = offset + extensions_length; - __u8 extensions_parsed = 0; - __u16 extension_type; - __u16 extension_length; - __u16 selected_version; - #pragma unroll(SERVER_MAX_EXTENSIONS) - for (int i = 0; i < SERVER_MAX_EXTENSIONS; i++) { - if (offset + 4 > extensions_end) { - break; - } - // Read Extension Type (2 bytes) - if (bpf_skb_load_bytes(skb, offset, &extension_type, sizeof(extension_type)) < 0) - return -1; - extension_type = bpf_ntohs(extension_type); - offset += 2; - - // Read Extension Length (2 bytes) - if (bpf_skb_load_bytes(skb, offset, &extension_length, sizeof(extension_length)) < 0) - return -1; - extension_length = bpf_ntohs(extension_length); - offset += 2; - - // Ensure we don't read beyond the packet - if (offset + extension_length > skb_len || offset + extension_length > extensions_end) - return -1; - - // Check for supported_versions extension (type 43 or 0x002B) - if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { - // Parse supported_versions extension - if (extension_length != 2) - return -1; - - if (offset + 2 > skb_len) - return -1; - - // Read selected version (2 bytes) - if (bpf_skb_load_bytes(skb, offset, &selected_version, sizeof(selected_version)) < 0) - return -1; - selected_version = bpf_ntohs(selected_version); - offset += 2; - - tags->chosen_version = selected_version; - } else { - // Skip other extensions - offset += extension_length; - } - - extensions_parsed++; - } - } + // Ensure we don't read beyond the packet + __u64 handshake_end = offset + handshake_length; + if (offset + extensions_length > skb_len || offset + extensions_length > handshake_end) + return -1; - return 0; -} + __u64 extensions_end = offset + extensions_length; -static __always_inline bool is_tls_handshake_client_hello(struct __sk_buff *skb, tls_record_header_t *tls_hdr, __u64 offset) { - if (!tls_hdr) { - return false; - } - if (tls_hdr->content_type != TLS_HANDSHAKE) { - return false; - } + // Parse Extensions + res = parse_tls_extensions(skb, &offset, extensions_end, skb_len, tags, false); + if (res != 0) + return res; - // Read handshake type - __u8 handshake_type; - if (bpf_skb_load_bytes(skb, offset, &handshake_type, sizeof(handshake_type)) < 0) { - return false; - } - return handshake_type == TLS_HANDSHAKE_CLIENT_HELLO; + return 0; } -static __always_inline bool is_tls_handshake_server_hello(struct __sk_buff *skb, tls_record_header_t *tls_hdr, __u64 offset) { - if (!tls_hdr) { + +static __always_inline bool is_tls_handshake_type(struct __sk_buff *skb, tls_record_header_t *tls_hdr, __u64 offset, __u8 expected_handshake_type) { + if (!tls_hdr) return false; - } - if (tls_hdr->content_type != TLS_HANDSHAKE) { + if (tls_hdr->content_type != TLS_HANDSHAKE) return false; - } // Read handshake type + if (offset + 1 > skb->len) + return false; __u8 handshake_type; if (bpf_skb_load_bytes(skb, offset, &handshake_type, sizeof(handshake_type)) < 0) return false; - return handshake_type == TLS_HANDSHAKE_SERVER_HELLO; + return handshake_type == expected_handshake_type; +} + +static __always_inline bool is_tls_handshake_client_hello(struct __sk_buff *skb, tls_record_header_t *tls_hdr, __u64 offset) { + return is_tls_handshake_type(skb, tls_hdr, offset, TLS_HANDSHAKE_CLIENT_HELLO); +} + +static __always_inline bool is_tls_handshake_server_hello(struct __sk_buff *skb, tls_record_header_t *tls_hdr, __u64 offset) { + return is_tls_handshake_type(skb, tls_hdr, offset, TLS_HANDSHAKE_SERVER_HELLO); } #endif // __TLS_H diff --git a/pkg/network/ebpf/c/tracer/events.h b/pkg/network/ebpf/c/tracer/events.h index ddecb87b09cd61..85f8ed48e1bab7 100644 --- a/pkg/network/ebpf/c/tracer/events.h +++ b/pkg/network/ebpf/c/tracer/events.h @@ -32,7 +32,7 @@ static __always_inline void clean_protocol_classification(conn_tuple_t *tup) { conn_tuple_t skb_tup = *skb_tup_ptr; delete_protocol_stack(&skb_tup, NULL, FLAG_TCP_CLOSE_DELETION); - bpf_map_delete_elem(&tls_enhanced_tags, &conn_tuple); + bpf_map_delete_elem(&tls_enhanced_tags, &skb_tup); bpf_map_delete_elem(&conn_tuple_to_socket_skb_conn_tuple, &conn_tuple); } diff --git a/pkg/network/ebpf/c/tracer/tracer.h b/pkg/network/ebpf/c/tracer/tracer.h index 0d51937a7ac187..66bd1bb7f65922 100644 --- a/pkg/network/ebpf/c/tracer/tracer.h +++ b/pkg/network/ebpf/c/tracer/tracer.h @@ -33,7 +33,6 @@ typedef struct { __u16 chosen_version; // 2 bytes __u16 cipher_suite; // 2 bytes __u8 offered_versions; // 1 byte (6 bits used) - __u8 reserved; // 1 byte (for alignment or future use) } tls_info_t; typedef struct { diff --git a/pkg/network/ebpf/kprobe_types_linux.go b/pkg/network/ebpf/kprobe_types_linux.go index a5bb92fea6061d..33947e240c7ccc 100644 --- a/pkg/network/ebpf/kprobe_types_linux.go +++ b/pkg/network/ebpf/kprobe_types_linux.go @@ -107,7 +107,7 @@ type TLSTags struct { Chosen_version uint16 Cipher_suite uint16 Offered_versions uint8 - Reserved uint8 + Pad_cgo_0 [1]byte } type _Ctype_struct_sock uint64 From c9079c5d54105cc920641ce7340c98f1cfb06976 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Wed, 4 Dec 2024 17:21:20 -0500 Subject: [PATCH 37/53] add release note --- .../add-tls-enhanced-tags-6ff09ae7fc0ff7a1.yaml | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 releasenotes/notes/add-tls-enhanced-tags-6ff09ae7fc0ff7a1.yaml diff --git a/releasenotes/notes/add-tls-enhanced-tags-6ff09ae7fc0ff7a1.yaml b/releasenotes/notes/add-tls-enhanced-tags-6ff09ae7fc0ff7a1.yaml new file mode 100644 index 00000000000000..8b9042a8823d42 --- /dev/null +++ b/releasenotes/notes/add-tls-enhanced-tags-6ff09ae7fc0ff7a1.yaml @@ -0,0 +1,13 @@ +# Each section from every release note are combined when the +# CHANGELOG.rst is rendered. So the text needs to be worded so that +# it does not depend on any information only available in another +# section. This may mean repeating some details, but each section +# must be readable independently of the other. +# +# Each section note must be formatted as reStructuredText. +--- +features: + - | + The agent will now tag TLS enhanced metrics like `tls_version` and `tls_cipher`. + This will allow you to filter and aggregate metrics based on the TLS version and cipher used in the connection. + The tags will be added in NPM and USM. From 8bb5cb47680e069ef60f55b30a19ea6b45395fc4 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Wed, 4 Dec 2024 17:58:47 -0500 Subject: [PATCH 38/53] add UT for failure scenario --- pkg/network/tracer/tracer_linux_test.go | 54 ++++++++++++++++++++++--- 1 file changed, 49 insertions(+), 5 deletions(-) diff --git a/pkg/network/tracer/tracer_linux_test.go b/pkg/network/tracer/tracer_linux_test.go index fd81ad811e8cb9..cb3150d8c6f849 100644 --- a/pkg/network/tracer/tracer_linux_test.go +++ b/pkg/network/tracer/tracer_linux_test.go @@ -2557,7 +2557,7 @@ func (s *TracerSuite) TestTLSClassification() { cfg := testConfig() if !kprobe.ClassificationSupported(cfg) { - t.Skip("TLS classification platform not supported") + t.Skip("protocol classification not supported") } port, err := tracertestutil.GetFreePort() require.NoError(t, err) @@ -2578,7 +2578,6 @@ func (s *TracerSuite) TestTLSClassification() { postTracerSetup: func(t *testing.T) { srv := usmtestutil.NewTLSServerWithSpecificVersion("localhost:"+portAsString, func(conn net.Conn) { defer conn.Close() - // Echo back whatever is received _, err := io.Copy(conn, conn) if err != nil { fmt.Printf("Failed to echo data: %v\n", err) @@ -2592,8 +2591,8 @@ func (s *TracerSuite) TestTLSClassification() { MinVersion: scenario, MaxVersion: scenario, InsecureSkipVerify: true, - SessionTicketsDisabled: true, // Disable session tickets - ClientSessionCache: nil, // Disable session cache + SessionTicketsDisabled: true, + ClientSessionCache: nil, } conn, err := net.Dial("tcp", "localhost:"+portAsString) require.NoError(t, err) @@ -2602,7 +2601,6 @@ func (s *TracerSuite) TestTLSClassification() { // Wrap the TCP connection with TLS tlsConn := tls.Client(conn, tlsConfig) - // Perform the TLS handshake require.NoError(t, tlsConn.Handshake()) }, validation: func(t *testing.T, tr *Tracer) { @@ -2612,6 +2610,52 @@ func (s *TracerSuite) TestTLSClassification() { }, }) } + tests = append(tests, tlsTest{ + name: "Invalid-TLS-Handshake", + postTracerSetup: func(t *testing.T) { + // server that accepts connections but does not perform TLS handshake + listener, err := net.Listen("tcp", "localhost:"+portAsString) + require.NoError(t, err) + t.Cleanup(func() { listener.Close() }) + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + buf := make([]byte, 1024) + _, _ = c.Read(buf) + // Do nothing + }(conn) + } + }() + + // Client connects to the server + conn, err := net.Dial("tcp", "localhost:"+portAsString) + require.NoError(t, err) + defer conn.Close() + + // Send invalid TLS handshake data + _, err = conn.Write([]byte("invalid TLS data")) + require.NoError(t, err) + }, + validation: func(t *testing.T, tr *Tracer) { + // Verify that no TLS tags are set for this connection + require.Eventually(t, func() bool { + payload := getConnections(t, tr) + for _, c := range payload.Conns { + if c.DPort == port && c.ProtocolStack.Contains(protocols.TLS) { + t.Log("Unexpected TLS protocol detected for invalid handshake") + return false + } + } + return true + }, 3*time.Second, 100*time.Millisecond) + }, + }) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { From 1e9c820e789a5babe0e547054bec8fe32eb3b486 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Wed, 4 Dec 2024 18:13:35 -0500 Subject: [PATCH 39/53] reduce diff slightly --- .../classification/routing-helpers.h | 22 ++++++++--------- .../ebpf/c/protocols/classification/routing.h | 6 ++--- .../protocols/classification/stack-helpers.h | 24 +++++++++---------- pkg/network/ebpf/c/tracer/tracer.h | 2 +- pkg/network/protocols/tls/types.go | 1 + 5 files changed, 27 insertions(+), 28 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/classification/routing-helpers.h b/pkg/network/ebpf/c/protocols/classification/routing-helpers.h index c06e6365c048c6..9ce157e4084b11 100644 --- a/pkg/network/ebpf/c/protocols/classification/routing-helpers.h +++ b/pkg/network/ebpf/c/protocols/classification/routing-helpers.h @@ -7,10 +7,10 @@ // has_available_program returns true when there is another program from within // the same protocol layer or false otherwise static __always_inline bool has_available_program(classification_prog_t current_program) { - classification_prog_t next_program = current_program + 1; - if (next_program == __PROG_ENCRYPTION || - next_program == __PROG_APPLICATION || + classification_prog_t next_program = current_program+1; + if (next_program == __PROG_APPLICATION || next_program == __PROG_API || + next_program == __PROG_ENCRYPTION || next_program == CLASSIFICATION_PROG_MAX) { return false; } @@ -19,12 +19,14 @@ static __always_inline bool has_available_program(classification_prog_t current_ } #pragma clang diagnostic push -// The following check is ignored because *currently* there are no API or -// Encryption classification programs registerd. +// The following check is ignored because *currently* there are no API classification programs registerd. // Therefore the enum containing all BPF programs looks like the following: // // typedef enum { // CLASSIFICATION_PROG_UNKNOWN = 0, +// __PROG_ENCRYPTION, +// ENCYPTION_PROG_A +// ... // __PROG_APPLICATION, // APPLICATION_PROG_A // APPLICATION_PROG_B @@ -32,8 +34,6 @@ static __always_inline bool has_available_program(classification_prog_t current_ // ... // __PROG_API, // // No programs here -// __PROG_ENCRYPTION, -// // No programs here // CLASSIFICATION_PROG_MAX, // } classification_prog_t; // @@ -45,11 +45,9 @@ static __always_inline u16 get_current_program_layer(classification_prog_t curre if (current_program > __PROG_ENCRYPTION && current_program < __PROG_APPLICATION) { return LAYER_ENCRYPTION_BIT; } - if (current_program > __PROG_APPLICATION && current_program < __PROG_API) { return LAYER_APPLICATION_BIT; } - if (current_program > __PROG_API && current_program < CLASSIFICATION_PROG_MAX) { return LAYER_API_BIT; } @@ -61,15 +59,15 @@ static __always_inline u16 get_current_program_layer(classification_prog_t curre static __always_inline classification_prog_t next_layer_entrypoint(usm_context_t *usm_ctx) { u16 to_skip = usm_ctx->routing_skip_layers; - if (!(to_skip&LAYER_ENCRYPTION_BIT)) { - return __PROG_ENCRYPTION+1; - } if (!(to_skip&LAYER_APPLICATION_BIT)) { return __PROG_APPLICATION+1; } if (!(to_skip&LAYER_API_BIT)) { return __PROG_API+1; } + if (!(to_skip&LAYER_ENCRYPTION_BIT)) { + return __PROG_ENCRYPTION+1; + } return CLASSIFICATION_PROG_UNKNOWN; } diff --git a/pkg/network/ebpf/c/protocols/classification/routing.h b/pkg/network/ebpf/c/protocols/classification/routing.h index 1529b518684a28..131fa15a918195 100644 --- a/pkg/network/ebpf/c/protocols/classification/routing.h +++ b/pkg/network/ebpf/c/protocols/classification/routing.h @@ -64,15 +64,15 @@ static __always_inline void init_routing_cache(usm_context_t *usm_ctx, protocol_ // 1) If the protocol for that layer is known, // except for encryption as it still needs to be traversed for metadata // 2) If there are no programs registered for that layer - if (stack->flags == FLAG_FULLY_CLASSIFIED || !has_available_program(__PROG_ENCRYPTION)) { - usm_ctx->routing_skip_layers |= LAYER_ENCRYPTION_BIT; - } if (stack->layer_application || !has_available_program(__PROG_APPLICATION)) { usm_ctx->routing_skip_layers |= LAYER_APPLICATION_BIT; } if (stack->layer_api || !has_available_program(__PROG_API)) { usm_ctx->routing_skip_layers |= LAYER_API_BIT; } + if (stack->flags == FLAG_FULLY_CLASSIFIED || !has_available_program(__PROG_ENCRYPTION)) { + usm_ctx->routing_skip_layers |= LAYER_ENCRYPTION_BIT; + } } #endif diff --git a/pkg/network/ebpf/c/protocols/classification/stack-helpers.h b/pkg/network/ebpf/c/protocols/classification/stack-helpers.h index 6b934c0c7b005c..b544de400c93d2 100644 --- a/pkg/network/ebpf/c/protocols/classification/stack-helpers.h +++ b/pkg/network/ebpf/c/protocols/classification/stack-helpers.h @@ -12,12 +12,12 @@ static __always_inline protocol_layer_t get_protocol_layer(protocol_t proto) { u16 layer_bit = proto&(LAYER_ENCRYPTION_BIT|LAYER_API_BIT|LAYER_APPLICATION_BIT); switch(layer_bit) { - case LAYER_ENCRYPTION_BIT: - return LAYER_ENCRYPTION; case LAYER_API_BIT: return LAYER_API; case LAYER_APPLICATION_BIT: return LAYER_APPLICATION; + case LAYER_ENCRYPTION_BIT: + return LAYER_ENCRYPTION; } return LAYER_UNKNOWN; @@ -37,15 +37,15 @@ static __always_inline void set_protocol(protocol_stack_t *stack, protocol_t pro // this is the the number of the protocol without the layer bit set __u8 proto_num = (__u8)proto; switch(layer) { - case LAYER_ENCRYPTION: - stack->layer_encryption = proto_num; - return; case LAYER_API: stack->layer_api = proto_num; return; case LAYER_APPLICATION: stack->layer_application = proto_num; return; + case LAYER_ENCRYPTION: + stack->layer_encryption = proto_num; + return; default: return; } @@ -92,10 +92,6 @@ __maybe_unused static __always_inline protocol_t get_protocol_from_stack(protoco __u16 proto_num = 0; __u16 layer_bit = 0; switch(layer) { - case LAYER_ENCRYPTION: - proto_num = stack->layer_encryption; - layer_bit = LAYER_ENCRYPTION_BIT; - break; case LAYER_API: proto_num = stack->layer_api; layer_bit = LAYER_API_BIT; @@ -104,6 +100,10 @@ __maybe_unused static __always_inline protocol_t get_protocol_from_stack(protoco proto_num = stack->layer_application; layer_bit = LAYER_APPLICATION_BIT; break; + case LAYER_ENCRYPTION: + proto_num = stack->layer_encryption; + layer_bit = LAYER_ENCRYPTION_BIT; + break; default: break; } @@ -131,15 +131,15 @@ static __always_inline void merge_protocol_stacks(protocol_stack_t *this, protoc return; } - if (!this->layer_encryption) { - this->layer_encryption = that->layer_encryption; - } if (!this->layer_api) { this->layer_api = that->layer_api; } if (!this->layer_application) { this->layer_application = that->layer_application; } + if (!this->layer_encryption) { + this->layer_encryption = that->layer_encryption; + } this->flags |= that->flags; } diff --git a/pkg/network/ebpf/c/tracer/tracer.h b/pkg/network/ebpf/c/tracer/tracer.h index 66bd1bb7f65922..eba253abd6cb70 100644 --- a/pkg/network/ebpf/c/tracer/tracer.h +++ b/pkg/network/ebpf/c/tracer/tracer.h @@ -32,7 +32,7 @@ typedef enum { typedef struct { __u16 chosen_version; // 2 bytes __u16 cipher_suite; // 2 bytes - __u8 offered_versions; // 1 byte (6 bits used) + __u8 offered_versions; // 1 byte } tls_info_t; typedef struct { diff --git a/pkg/network/protocols/tls/types.go b/pkg/network/protocols/tls/types.go index 348bcdbd407fbd..35f853135eb449 100644 --- a/pkg/network/protocols/tls/types.go +++ b/pkg/network/protocols/tls/types.go @@ -20,6 +20,7 @@ const ( ) // mapping of version constants to their string representations +// TODO: use built strings from crypto/tls var tlsVersionNames = map[uint16]string{ tls.VersionTLS10: "TLS 1.0", tls.VersionTLS11: "TLS 1.1", From 4bd350557cc8440f297362123ae2df47858f83ee Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Wed, 4 Dec 2024 23:02:39 -0500 Subject: [PATCH 40/53] revert some non-functional changes --- .../ebpf/c/protocols/classification/defs.h | 22 ++++----- .../classification/protocol-classification.h | 3 +- .../classification/routing-helpers.h | 26 +---------- .../protocols/classification/stack-helpers.h | 2 +- pkg/network/ebpf/c/tracer/tracer.h | 6 +-- pkg/network/ebpf/kprobe_types_linux.go | 2 +- pkg/network/protocols/tls/types.go | 9 ++-- pkg/network/protocols/tls/types_test.go | 46 +++++++++---------- pkg/network/tags_linux.go | 2 +- pkg/network/tracer/tracer_linux_test.go | 8 ++-- 10 files changed, 50 insertions(+), 76 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/classification/defs.h b/pkg/network/ebpf/c/protocols/classification/defs.h index 2a690bea9767f0..c91dc427d9e18c 100644 --- a/pkg/network/ebpf/c/protocols/classification/defs.h +++ b/pkg/network/ebpf/c/protocols/classification/defs.h @@ -21,13 +21,13 @@ // The maximum number of protocols per stack layer #define MAX_ENTRIES_PER_LAYER 255 -#define LAYER_ENCRYPTION_BIT (1 << 13) +#define LAYER_API_BIT (1 << 13) #define LAYER_APPLICATION_BIT (1 << 14) -#define LAYER_API_BIT (1 << 15) +#define LAYER_ENCRYPTION_BIT (1 << 15) -#define LAYER_ENCRYPTION_MAX (LAYER_ENCRYPTION_BIT + MAX_ENTRIES_PER_LAYER) -#define LAYER_APPLICATION_MAX (LAYER_APPLICATION_BIT + MAX_ENTRIES_PER_LAYER) #define LAYER_API_MAX (LAYER_API_BIT + MAX_ENTRIES_PER_LAYER) +#define LAYER_APPLICATION_MAX (LAYER_APPLICATION_BIT + MAX_ENTRIES_PER_LAYER) +#define LAYER_ENCRYPTION_MAX (LAYER_ENCRYPTION_BIT + MAX_ENTRIES_PER_LAYER) #define FLAG_FULLY_CLASSIFIED 1 << 0 #define FLAG_USM_ENABLED 1 << 1 @@ -48,11 +48,6 @@ typedef enum { PROTOCOL_UNKNOWN = 0, - __LAYER_ENCRYPTION_MIN = LAYER_ENCRYPTION_BIT, - // Add encryption protocols below (eg. TLS) - PROTOCOL_TLS, - __LAYER_ENCRYPTION_MAX = LAYER_ENCRYPTION_MAX, - __LAYER_API_MIN = LAYER_API_BIT, // Add API protocols here (eg. gRPC) PROTOCOL_GRPC, @@ -70,6 +65,10 @@ typedef enum { PROTOCOL_MYSQL, __LAYER_APPLICATION_MAX = LAYER_APPLICATION_MAX, + __LAYER_ENCRYPTION_MIN = LAYER_ENCRYPTION_BIT, + // Add encryption protocols below (eg. TLS) + PROTOCOL_TLS, + __LAYER_ENCRYPTION_MAX = LAYER_ENCRYPTION_MAX, } __attribute__ ((packed)) protocol_t; // This enum represents all existing protocol layers @@ -81,15 +80,15 @@ typedef enum { // users can call `get_protocol_layer` typedef enum { LAYER_UNKNOWN, - LAYER_ENCRYPTION, LAYER_API, LAYER_APPLICATION, + LAYER_ENCRYPTION, } __attribute__ ((packed)) protocol_layer_t; typedef struct { - __u8 layer_encryption; __u8 layer_api; __u8 layer_application; + __u8 layer_encryption; __u8 flags; } protocol_stack_t; @@ -126,7 +125,6 @@ typedef enum { __PROG_API, // API classification programs go here CLASSIFICATION_GRPC_PROG, - // Add before this value CLASSIFICATION_PROG_MAX, } classification_prog_t; diff --git a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h index d6c295f7013562..dd697a85606150 100644 --- a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h +++ b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h @@ -164,13 +164,12 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct tls_record_header_t tls_hdr = {0}; - // TLS classification if ((app_layer_proto == PROTOCOL_UNKNOWN || app_layer_proto == PROTOCOL_POSTGRES) && is_tls(skb, skb_info.data_off, &tls_hdr)) { protocol_stack = get_or_create_protocol_stack(&usm_ctx->tuple); if (!protocol_stack) { return; } - // TLS classification + update_protocol_information(usm_ctx, protocol_stack, PROTOCOL_TLS); if (tls_hdr.content_type == TLS_APPLICATION_DATA) { // We can't classify TLS encrypted traffic further, so we mark the stack as fully classified diff --git a/pkg/network/ebpf/c/protocols/classification/routing-helpers.h b/pkg/network/ebpf/c/protocols/classification/routing-helpers.h index 9ce157e4084b11..83b27ba19644c2 100644 --- a/pkg/network/ebpf/c/protocols/classification/routing-helpers.h +++ b/pkg/network/ebpf/c/protocols/classification/routing-helpers.h @@ -18,29 +18,7 @@ static __always_inline bool has_available_program(classification_prog_t current_ return true; } -#pragma clang diagnostic push -// The following check is ignored because *currently* there are no API classification programs registerd. -// Therefore the enum containing all BPF programs looks like the following: -// -// typedef enum { -// CLASSIFICATION_PROG_UNKNOWN = 0, -// __PROG_ENCRYPTION, -// ENCYPTION_PROG_A -// ... -// __PROG_APPLICATION, -// APPLICATION_PROG_A -// APPLICATION_PROG_B -// APPLICATION_PROG_C -// ... -// __PROG_API, -// // No programs here -// CLASSIFICATION_PROG_MAX, -// } classification_prog_t; -// -// Which means that the following conditionals will always evaluate to false: -// a) current_program > __PROG_API && current_program < __PROG_ENCRYPTION -// b) current_program > __PROG_ENCRYPTION && current_program < CLASSIFICATION_PROG_MAX -#pragma clang diagnostic ignored "-Wtautological-overlap-compare" +// get_current_program_layer returns the layer bit of the current program static __always_inline u16 get_current_program_layer(classification_prog_t current_program) { if (current_program > __PROG_ENCRYPTION && current_program < __PROG_APPLICATION) { return LAYER_ENCRYPTION_BIT; @@ -54,8 +32,8 @@ static __always_inline u16 get_current_program_layer(classification_prog_t curre return 0; } -#pragma clang diagnostic pop +// next_layer_entrypoint returns the entrypoint of the next layer that should be executed static __always_inline classification_prog_t next_layer_entrypoint(usm_context_t *usm_ctx) { u16 to_skip = usm_ctx->routing_skip_layers; diff --git a/pkg/network/ebpf/c/protocols/classification/stack-helpers.h b/pkg/network/ebpf/c/protocols/classification/stack-helpers.h index b544de400c93d2..5cba1cdf8e30ac 100644 --- a/pkg/network/ebpf/c/protocols/classification/stack-helpers.h +++ b/pkg/network/ebpf/c/protocols/classification/stack-helpers.h @@ -9,7 +9,7 @@ // get_protocol_layer(PROTOCOL_HTTP) => LAYER_APPLICATION // get_protocol_layer(PROTOCOL_TLS) => LAYER_ENCRYPTION static __always_inline protocol_layer_t get_protocol_layer(protocol_t proto) { - u16 layer_bit = proto&(LAYER_ENCRYPTION_BIT|LAYER_API_BIT|LAYER_APPLICATION_BIT); + u16 layer_bit = proto&(LAYER_API_BIT|LAYER_APPLICATION_BIT|LAYER_ENCRYPTION_BIT); switch(layer_bit) { case LAYER_API_BIT: diff --git a/pkg/network/ebpf/c/tracer/tracer.h b/pkg/network/ebpf/c/tracer/tracer.h index eba253abd6cb70..8ca45be89fed8f 100644 --- a/pkg/network/ebpf/c/tracer/tracer.h +++ b/pkg/network/ebpf/c/tracer/tracer.h @@ -30,9 +30,9 @@ typedef enum { #define CONN_DIRECTION_MASK 0b11 typedef struct { - __u16 chosen_version; // 2 bytes - __u16 cipher_suite; // 2 bytes - __u8 offered_versions; // 1 byte + __u16 chosen_version; + __u16 cipher_suite; + __u8 offered_versions; } tls_info_t; typedef struct { diff --git a/pkg/network/ebpf/kprobe_types_linux.go b/pkg/network/ebpf/kprobe_types_linux.go index 33947e240c7ccc..9339ac1ec0f281 100644 --- a/pkg/network/ebpf/kprobe_types_linux.go +++ b/pkg/network/ebpf/kprobe_types_linux.go @@ -94,9 +94,9 @@ type BindSyscallArgs struct { Sk uint64 } type ProtocolStack struct { - Encryption uint8 Api uint8 Application uint8 + Encryption uint8 Flags uint8 } type ProtocolStackWrapper struct { diff --git a/pkg/network/protocols/tls/types.go b/pkg/network/protocols/tls/types.go index 35f853135eb449..680c6e8156124d 100644 --- a/pkg/network/protocols/tls/types.go +++ b/pkg/network/protocols/tls/types.go @@ -20,12 +20,11 @@ const ( ) // mapping of version constants to their string representations -// TODO: use built strings from crypto/tls var tlsVersionNames = map[uint16]string{ - tls.VersionTLS10: "TLS 1.0", - tls.VersionTLS11: "TLS 1.1", - tls.VersionTLS12: "TLS 1.2", - tls.VersionTLS13: "TLS 1.3", + tls.VersionTLS10: "tls_1.0", + tls.VersionTLS11: "tls_1.1", + tls.VersionTLS12: "tls_1.2", + tls.VersionTLS13: "tls_1.3", } // Mapping of offered version bitmasks to version constants diff --git a/pkg/network/protocols/tls/types_test.go b/pkg/network/protocols/tls/types_test.go index 911293f3a32c96..b812bb4edb460d 100644 --- a/pkg/network/protocols/tls/types_test.go +++ b/pkg/network/protocols/tls/types_test.go @@ -17,10 +17,10 @@ func TestFormatTLSVersion(t *testing.T) { version uint16 expected string }{ - {tls.VersionTLS10, "TLS 1.0"}, - {tls.VersionTLS11, "TLS 1.1"}, - {tls.VersionTLS12, "TLS 1.2"}, - {tls.VersionTLS13, "TLS 1.3"}, + {tls.VersionTLS10, "tls_1.0"}, + {tls.VersionTLS11, "tls_1.1"}, + {tls.VersionTLS12, "tls_1.2"}, + {tls.VersionTLS13, "tls_1.3"}, {0xFFFF, ""}, // Unknown version {0x0000, ""}, // Zero value {0x0305, ""}, // Version just above known versions @@ -43,13 +43,13 @@ func TestParseOfferedVersions(t *testing.T) { expected []string }{ {0x00, []string{}}, // No versions offered - {OfferedTLSVersion10, []string{"TLS 1.0"}}, - {OfferedTLSVersion11, []string{"TLS 1.1"}}, - {OfferedTLSVersion12, []string{"TLS 1.2"}}, - {OfferedTLSVersion13, []string{"TLS 1.3"}}, - {OfferedTLSVersion10 | OfferedTLSVersion12, []string{"TLS 1.0", "TLS 1.2"}}, - {OfferedTLSVersion11 | OfferedTLSVersion13, []string{"TLS 1.1", "TLS 1.3"}}, - {0xFF, []string{"TLS 1.0", "TLS 1.1", "TLS 1.2", "TLS 1.3"}}, // All bits set + {OfferedTLSVersion10, []string{"tls_1.0"}}, + {OfferedTLSVersion11, []string{"tls_1.1"}}, + {OfferedTLSVersion12, []string{"tls_1.2"}}, + {OfferedTLSVersion13, []string{"tls_1.3"}}, + {OfferedTLSVersion10 | OfferedTLSVersion12, []string{"tls_1.0", "tls_1.2"}}, + {OfferedTLSVersion11 | OfferedTLSVersion13, []string{"tls_1.1", "tls_1.3"}}, + {0xFF, []string{"tls_1.0", "tls_1.1", "tls_1.2", "tls_1.3"}}, // All bits set {0x40, []string{}}, // Undefined bit set {0x80, []string{}}, // Undefined bit set } @@ -83,10 +83,10 @@ func TestGetTLSDynamicTags(t *testing.T) { OfferedVersions: OfferedTLSVersion11 | OfferedTLSVersion12, }, expected: map[string]struct{}{ - "tls.version:TLS 1.2": {}, + "tls.version:tls_1.2": {}, "tls.cipher_suite_id:0x009C": {}, - "tls.client_version:TLS 1.1": {}, - "tls.client_version:TLS 1.2": {}, + "tls.client_version:tls_1.1": {}, + "tls.client_version:tls_1.2": {}, }, }, { @@ -98,7 +98,7 @@ func TestGetTLSDynamicTags(t *testing.T) { }, expected: map[string]struct{}{ "tls.cipher_suite_id:0x00FF": {}, - "tls.client_version:TLS 1.3": {}, + "tls.client_version:tls_1.3": {}, }, }, { @@ -109,7 +109,7 @@ func TestGetTLSDynamicTags(t *testing.T) { OfferedVersions: 0x00, }, expected: map[string]struct{}{ - "tls.version:TLS 1.3": {}, + "tls.version:tls_1.3": {}, "tls.cipher_suite_id:0x1301": {}, }, }, @@ -120,8 +120,8 @@ func TestGetTLSDynamicTags(t *testing.T) { OfferedVersions: OfferedTLSVersion10, }, expected: map[string]struct{}{ - "tls.version:TLS 1.0": {}, - "tls.client_version:TLS 1.0": {}, + "tls.version:tls_1.0": {}, + "tls.client_version:tls_1.0": {}, }, }, { @@ -132,12 +132,12 @@ func TestGetTLSDynamicTags(t *testing.T) { OfferedVersions: 0xFF, // All bits set }, expected: map[string]struct{}{ - "tls.version:TLS 1.2": {}, + "tls.version:tls_1.2": {}, "tls.cipher_suite_id:0xC02F": {}, - "tls.client_version:TLS 1.0": {}, - "tls.client_version:TLS 1.1": {}, - "tls.client_version:TLS 1.2": {}, - "tls.client_version:TLS 1.3": {}, + "tls.client_version:tls_1.0": {}, + "tls.client_version:tls_1.1": {}, + "tls.client_version:tls_1.2": {}, + "tls.client_version:tls_1.3": {}, }, }, } diff --git a/pkg/network/tags_linux.go b/pkg/network/tags_linux.go index 4c2261a2457684..4f81692cfc31a5 100644 --- a/pkg/network/tags_linux.go +++ b/pkg/network/tags_linux.go @@ -26,7 +26,7 @@ const ( ConnTagNodeJS = http.NodeJS ) -// GetStaticTags return the string list of static tags from network.ConnectionStats.StaticTags +// GetStaticTags return the string list of static tags from network.ConnectionStats.Tags func GetStaticTags(staticTags uint64) (tags []string) { for tag, str := range http.StaticTags { if (staticTags & tag) > 0 { diff --git a/pkg/network/tracer/tracer_linux_test.go b/pkg/network/tracer/tracer_linux_test.go index cb3150d8c6f849..8eb0cfc6dc6254 100644 --- a/pkg/network/tracer/tracer_linux_test.go +++ b/pkg/network/tracer/tracer_linux_test.go @@ -2695,14 +2695,14 @@ func validateTLSTags(t *testing.T, tr *Tracer, port uint16, scenario uint16) boo } // Check that the negotiated version tag is present - negotiatedVersionTag := ddtls.TagTLSVersion + tls.VersionName(scenario) + negotiatedVersionTag := ddtls.TagTLSVersion + ddtls.FormatTLSVersion(scenario) if _, ok := tlsTags[negotiatedVersionTag]; !ok { t.Logf("Negotiated version tag '%s' not found", negotiatedVersionTag) return false } // Check that the client offered version tag is present - clientVersionTag := ddtls.TagTLSClientVersion + tls.VersionName(scenario) + clientVersionTag := ddtls.TagTLSClientVersion + ddtls.FormatTLSVersion(scenario) if _, ok := tlsTags[clientVersionTag]; !ok { t.Logf("Client offered version tag '%s' not found", clientVersionTag) return false @@ -2710,8 +2710,8 @@ func validateTLSTags(t *testing.T, tr *Tracer, port uint16, scenario uint16) boo if scenario == tls.VersionTLS13 { expectedClientVersions := []string{ - ddtls.TagTLSClientVersion + tls.VersionName(tls.VersionTLS12), - ddtls.TagTLSClientVersion + tls.VersionName(tls.VersionTLS13), + ddtls.TagTLSClientVersion + ddtls.FormatTLSVersion(tls.VersionTLS12), + ddtls.TagTLSClientVersion + ddtls.FormatTLSVersion(tls.VersionTLS13), } for _, tag := range expectedClientVersions { if _, ok := tlsTags[tag]; !ok { From 0e55408eb8a7a60fb529f0a9c77d938d3d869db9 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Thu, 5 Dec 2024 12:09:05 -0500 Subject: [PATCH 41/53] update releasenotes --- .../notes/add-tls-enhanced-tags-6ff09ae7fc0ff7a1.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/releasenotes/notes/add-tls-enhanced-tags-6ff09ae7fc0ff7a1.yaml b/releasenotes/notes/add-tls-enhanced-tags-6ff09ae7fc0ff7a1.yaml index 8b9042a8823d42..14ed9c4ee8c35a 100644 --- a/releasenotes/notes/add-tls-enhanced-tags-6ff09ae7fc0ff7a1.yaml +++ b/releasenotes/notes/add-tls-enhanced-tags-6ff09ae7fc0ff7a1.yaml @@ -8,6 +8,6 @@ --- features: - | - The agent will now tag TLS enhanced metrics like `tls_version` and `tls_cipher`. + The Agent will now tag TLS enhanced metrics like `tls_version` and `tls_cipher`. This will allow you to filter and aggregate metrics based on the TLS version and cipher used in the connection. - The tags will be added in NPM and USM. + The tags will be added in CNM and USM. From 300686d05a49d5a8235b17b9610abddcaeaf0cf0 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Mon, 9 Dec 2024 17:59:09 -0500 Subject: [PATCH 42/53] use map cleaner for tls tags --- .../classification/protocol-classification.h | 9 +- .../classification/shared-tracer-maps.h | 24 +- pkg/network/ebpf/c/protocols/tls/tls.h | 312 ++++++++++-------- pkg/network/ebpf/c/tracer/tracer.h | 5 + pkg/network/ebpf/kprobe_types.go | 1 + pkg/network/ebpf/kprobe_types_linux.go | 4 + pkg/network/protocols/tls/types.go | 5 +- pkg/network/tracer/connection/ebpf_tracer.go | 37 ++- pkg/network/usm/ebpf_main.go | 4 - 9 files changed, 245 insertions(+), 156 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h index dd697a85606150..38062c1fc7d637 100644 --- a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h +++ b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h @@ -172,8 +172,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct update_protocol_information(usm_ctx, protocol_stack, PROTOCOL_TLS); if (tls_hdr.content_type == TLS_APPLICATION_DATA) { - // We can't classify TLS encrypted traffic further, so we mark the stack as fully classified - mark_as_fully_classified(protocol_stack); + // We can't classify TLS encrypted traffic further, so return early return; } @@ -229,7 +228,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint_tls_ha if (!is_tls_handshake_client_hello(skb, &usm_ctx->tls_header, offset)) { goto next_program; } - if (parse_client_hello(skb, offset, skb->len, tls_info) != 0) { + if (!parse_client_hello(skb, offset, tls_info)) { return; } @@ -250,7 +249,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint_tls_ha if (!is_tls_handshake_server_hello(skb, &usm_ctx->tls_header, offset)) { goto next_program; } - if (parse_server_hello(skb, offset, skb->len, tls_info) != 0) { + if (!parse_server_hello(skb, offset, tls_info)) { return; } @@ -258,8 +257,6 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint_tls_ha if (!protocol_stack) { return; } - mark_as_fully_classified(protocol_stack); - usm_ctx->tls_header = (tls_record_header_t){0}; return; next_program: diff --git a/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h b/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h index e673aedb8c9982..d8e1c898f47b73 100644 --- a/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h +++ b/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h @@ -10,7 +10,8 @@ // classification procedures on the same connection BPF_HASH_MAP(connection_protocol, conn_tuple_t, protocol_stack_wrapper_t, 0) -BPF_HASH_MAP(tls_enhanced_tags, conn_tuple_t, tls_info_t, 0) +// Map to store extra information about TLS connections like version, cipher, etc. +BPF_HASH_MAP(tls_enhanced_tags, conn_tuple_t, tls_info_wrapper_t, 1) static __always_inline bool is_protocol_classification_supported() { __u64 val = 0; @@ -151,7 +152,12 @@ __maybe_unused static __always_inline void delete_protocol_stack(conn_tuple_t* n static __always_inline tls_info_t* get_tls_enhanced_tags(conn_tuple_t* tuple) { conn_tuple_t normalized_tup = *tuple; normalize_tuple(&normalized_tup); - return bpf_map_lookup_elem(&tls_enhanced_tags, &normalized_tup); + tls_info_wrapper_t *wrapper = bpf_map_lookup_elem(&tls_enhanced_tags, &normalized_tup); + if (!wrapper) { + return NULL; + } + wrapper->updated = bpf_ktime_get_ns(); + return &wrapper->info; } static __always_inline tls_info_t* get_or_create_tls_enhanced_tags(conn_tuple_t *tuple) { @@ -159,11 +165,17 @@ static __always_inline tls_info_t* get_or_create_tls_enhanced_tags(conn_tuple_t if (!tags) { conn_tuple_t normalized_tup = *tuple; normalize_tuple(&normalized_tup); - tls_info_t empty_tags = {0}; - bpf_map_update_elem(&tls_enhanced_tags, &normalized_tup, &empty_tags, BPF_ANY); - tags = bpf_map_lookup_elem(&tls_enhanced_tags, &normalized_tup); + tls_info_wrapper_t empty_tags_wrapper = {}; + empty_tags_wrapper.updated = bpf_ktime_get_ns(); + + bpf_map_update_elem(&tls_enhanced_tags, &normalized_tup, &empty_tags_wrapper, BPF_ANY); + tls_info_wrapper_t *wrapper_ptr = bpf_map_lookup_elem(&tls_enhanced_tags, &normalized_tup); + if (!wrapper_ptr) { + return NULL; + } + tags = &wrapper_ptr->info; } - return tags; + return tags; } #endif diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index 8281f53421b1c6..e44b4ace317ee8 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -17,6 +17,11 @@ #define TLS_HANDSHAKE_CLIENT_HELLO 0x01 #define TLS_HANDSHAKE_SERVER_HELLO 0x02 +#define TLS_VERSION10_BIT 0x01 +#define TLS_VERSION11_BIT 0x02 +#define TLS_VERSION12_BIT 0x04 +#define TLS_VERSION13_BIT 0x08 + // TLS extensions to parse from the Hello message when searching for the SUPPORTED_VERSIONS_EXTENSION #define MAX_EXTENSIONS 16 #define SUPPORTED_VERSIONS_EXTENSION 0x002B @@ -24,6 +29,15 @@ // this corresponds to 16 KB, which is the maximum TLS record size as per the specification #define TLS_MAX_PAYLOAD_LENGTH (1 << 14) +// Byte lengths of fields in the TLS handshake header +#define RANDOM_LENGTH 32 +#define TLS_HANDSHAKE_LENGTH 3 +#define PROTOCOL_VERSION_LENGTH 2 +#define SESSION_ID_LENGTH 1 +#define CIPHER_SUITES_LENGTH 2 +#define COMPRESSION_METHODS_LENGTH 1 +#define EXTENSION_TYPE_LENGTH 2 + // TLS record layer header structure (https://www.rfc-editor.org/rfc/rfc5246#page-19) typedef struct { __u8 content_type; @@ -50,16 +64,16 @@ static __always_inline bool is_valid_tls_version(__u16 version) { static __always_inline void set_tls_offered_version(tls_info_t *tls_info, __u16 version) { switch (version) { case TLS_VERSION10: - tls_info->offered_versions |= 0x01; + tls_info->offered_versions |= TLS_VERSION10_BIT; break; case TLS_VERSION11: - tls_info->offered_versions |= 0x02; + tls_info->offered_versions |= TLS_VERSION11_BIT; break; case TLS_VERSION12: - tls_info->offered_versions |= 0x04; + tls_info->offered_versions |= TLS_VERSION12_BIT; break; case TLS_VERSION13: - tls_info->offered_versions |= 0x08; + tls_info->offered_versions |= TLS_VERSION13_BIT; break; default: break; @@ -71,109 +85,122 @@ static __always_inline bool read_tls_record_header(struct __sk_buff *skb, __u64 __u32 skb_len = skb->len; // Ensure there's enough space for TLS record header - if (nh_off + sizeof(tls_record_header_t) > skb_len) + if (nh_off + sizeof(tls_record_header_t) > skb_len) { return false; + } // Read TLS record header - if (bpf_skb_load_bytes(skb, nh_off, tls_hdr, sizeof(tls_record_header_t)) < 0) + if (bpf_skb_load_bytes(skb, nh_off, tls_hdr, sizeof(tls_record_header_t)) < 0) { return false; + } // Convert fields to host byte order tls_hdr->version = bpf_ntohs(tls_hdr->version); tls_hdr->length = bpf_ntohs(tls_hdr->length); // Validate version and length - if (!is_valid_tls_version(tls_hdr->version)) + if (!is_valid_tls_version(tls_hdr->version)) { return false; - if (tls_hdr->length > TLS_MAX_PAYLOAD_LENGTH) + } + if (tls_hdr->length > TLS_MAX_PAYLOAD_LENGTH) { return false; + } // Ensure we don't read beyond the packet - if (nh_off + sizeof(tls_record_header_t) + tls_hdr->length > skb_len) - return false; - - return true; + return nh_off + sizeof(tls_record_header_t) + tls_hdr->length <= skb_len; } // is_tls checks if the packet is a TLS packet and reads the TLS record header static __always_inline bool is_tls(struct __sk_buff *skb, __u64 nh_off, tls_record_header_t *tls_hdr) { // Use the helper function to read and validate the TLS record header - if (!read_tls_record_header(skb, nh_off, tls_hdr)) + if (!read_tls_record_header(skb, nh_off, tls_hdr)) { return false; + } // Validate content type - if (tls_hdr->content_type != TLS_HANDSHAKE && tls_hdr->content_type != TLS_APPLICATION_DATA) - return false; - - return true; + return tls_hdr->content_type == TLS_HANDSHAKE || tls_hdr->content_type == TLS_APPLICATION_DATA; } -static __always_inline int parse_tls_handshake_header(struct __sk_buff *skb, __u64 *offset, __u32 skb_len, __u32 *handshake_length, __u16 *protocol_version) { +static __always_inline bool parse_tls_handshake_header(struct __sk_buff *skb, __u64 *offset, __u32 *handshake_length, __u16 *protocol_version) { // Move offset past handshake type (1 byte) *offset += 1; + __u32 skb_len = skb->len; // Read handshake length (3 bytes) - if (*offset + 3 > skb_len) - return -1; - __u8 handshake_length_bytes[3]; - if (bpf_skb_load_bytes(skb, *offset, handshake_length_bytes, 3) < 0) - return -1; + if (*offset + TLS_HANDSHAKE_LENGTH > skb_len) { + return false; + } + __u8 handshake_length_bytes[TLS_HANDSHAKE_LENGTH]; + if (bpf_skb_load_bytes(skb, *offset, handshake_length_bytes, TLS_HANDSHAKE_LENGTH) < 0) { + return false; + } *handshake_length = (handshake_length_bytes[0] << 16) | (handshake_length_bytes[1] << 8) | handshake_length_bytes[2]; - *offset += 3; + *offset += TLS_HANDSHAKE_LENGTH; // Ensure we don't read beyond the packet - if (*offset + *handshake_length > skb_len) - return -1; + if (*offset + *handshake_length > skb_len) { + return false; + } // Read protocol version (2 bytes) - if (*offset + 2 > skb_len) - return -1; + if (*offset + PROTOCOL_VERSION_LENGTH > skb_len) { + return false; + } __u16 version; - if (bpf_skb_load_bytes(skb, *offset, &version, sizeof(version)) < 0) - return -1; + if (bpf_skb_load_bytes(skb, *offset, &version, sizeof(version)) < 0) { + return false; + } *protocol_version = bpf_ntohs(version); - *offset += 2; + *offset += PROTOCOL_VERSION_LENGTH; - return 0; + return true; } -static __always_inline int skip_random_and_session_id(struct __sk_buff *skb, __u64 *offset, __u32 skb_len) { +static __always_inline bool skip_random_and_session_id(struct __sk_buff *skb, __u64 *offset) { // Skip Random (32 bytes) - *offset += 32; + *offset += RANDOM_LENGTH; + __u32 skb_len = skb->len; // Read Session ID Length (1 byte) - if (*offset + 1 > skb_len) - return -1; + if (*offset + SESSION_ID_LENGTH > skb_len) { + return false; + } __u8 session_id_length; - if (bpf_skb_load_bytes(skb, *offset, &session_id_length, sizeof(session_id_length)) < 0) - return -1; - *offset += 1; + if (bpf_skb_load_bytes(skb, *offset, &session_id_length, sizeof(session_id_length)) < 0) { + return false; + } + *offset += SESSION_ID_LENGTH; // Skip Session ID *offset += session_id_length; // Ensure we don't read beyond the packet - if (*offset > skb_len) - return -1; + if (*offset > skb_len) { + return false; + } - return 0; + return true; } -static __always_inline int parse_supported_versions_extension(struct __sk_buff *skb, __u64 *offset, __u64 extensions_end, __u32 skb_len, tls_info_t *tags, bool is_client_hello) { +static __always_inline bool parse_supported_versions_extension(struct __sk_buff *skb, __u64 *offset, __u64 extensions_end, tls_info_t *tags, bool is_client_hello) { + __u32 skb_len = skb->len; if (is_client_hello) { // Read list length (1 byte) - if (*offset + 1 > skb_len || *offset + 1 > extensions_end) - return -1; + if (*offset + 1 > skb_len || *offset + 1 > extensions_end) { + return false; + } __u8 sv_list_length; - if (bpf_skb_load_bytes(skb, *offset, &sv_list_length, sizeof(sv_list_length)) < 0) - return -1; + if (bpf_skb_load_bytes(skb, *offset, &sv_list_length, sizeof(sv_list_length)) < 0) { + return false; + } *offset += 1; // Ensure we don't read beyond the packet - if (*offset + sv_list_length > skb_len || *offset + sv_list_length > extensions_end) - return -1; + if (*offset + sv_list_length > skb_len || *offset + sv_list_length > extensions_end) { + return false; + } // Parse the list of supported versions __u8 sv_offset = 0; @@ -181,14 +208,17 @@ static __always_inline int parse_supported_versions_extension(struct __sk_buff * #define MAX_SUPPORTED_VERSIONS 4 #pragma unroll(MAX_SUPPORTED_VERSIONS) for (int idx = 0; idx < MAX_SUPPORTED_VERSIONS; idx++) { - if (sv_offset + 1 >= sv_list_length) + if (sv_offset + 1 >= sv_list_length) { break; - if (*offset + 2 > skb_len) - return -1; + } + if (*offset + 2 > skb_len) { + return false; + } // Load the supported version - if (bpf_skb_load_bytes(skb, *offset, &sv_version, sizeof(sv_version)) < 0) - return -1; + if (bpf_skb_load_bytes(skb, *offset, &sv_version, sizeof(sv_version)) < 0) { + return false; + } sv_version = bpf_ntohs(sv_version); *offset += 2; @@ -200,24 +230,26 @@ static __always_inline int parse_supported_versions_extension(struct __sk_buff * } else { // ServerHello // Extension Length should be 2 - if (*offset + 2 > skb_len) - return -1; + if (*offset + 2 > skb_len) { + return false; + } // Read selected version (2 bytes) __u16 selected_version; - if (bpf_skb_load_bytes(skb, *offset, &selected_version, sizeof(selected_version)) < 0) - return -1; + if (bpf_skb_load_bytes(skb, *offset, &selected_version, sizeof(selected_version)) < 0) { + return false; + } selected_version = bpf_ntohs(selected_version); *offset += 2; tags->chosen_version = selected_version; } - return 0; + return true; } -static __always_inline int parse_tls_extensions(struct __sk_buff *skb, __u64 *offset, __u64 extensions_end, __u32 skb_len, tls_info_t *tags, bool is_client_hello) { +static __always_inline bool parse_tls_extensions(struct __sk_buff *skb, __u64 *offset, __u64 extensions_end, tls_info_t *tags, bool is_client_hello) { __u16 extension_type; __u16 extension_length; @@ -227,123 +259,131 @@ static __always_inline int parse_tls_extensions(struct __sk_buff *skb, __u64 *of break; } // Read Extension Type (2 bytes) - if (bpf_skb_load_bytes(skb, *offset, &extension_type, sizeof(extension_type)) < 0) - return -1; + if (bpf_skb_load_bytes(skb, *offset, &extension_type, sizeof(extension_type)) < 0) { + return false; + } extension_type = bpf_ntohs(extension_type); *offset += 2; // Read Extension Length (2 bytes) - if (bpf_skb_load_bytes(skb, *offset, &extension_length, sizeof(extension_length)) < 0) - return -1; + if (bpf_skb_load_bytes(skb, *offset, &extension_length, sizeof(extension_length)) < 0) { + return false; + } extension_length = bpf_ntohs(extension_length); *offset += 2; // Ensure we don't read beyond the packet - if (*offset + extension_length > skb_len || *offset + extension_length > extensions_end) - return -1; + if (*offset + extension_length > skb->len || *offset + extension_length > extensions_end) { + return false; + } // Check for supported_versions extension if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { - int res = parse_supported_versions_extension(skb, offset, extensions_end, skb_len, tags, is_client_hello); - if (res != 0) - return res; + if (!parse_supported_versions_extension(skb, offset, extensions_end, tags, is_client_hello)) { + return false; + } } else { // Skip other extensions *offset += extension_length; } // Ensure we don't run past the extensions_end - if (*offset >= extensions_end) + if (*offset >= extensions_end) { break; + } } - return 0; + return true; } // parse_client_hello reads the ClientHello message from the TLS handshake and populates select tags -static __always_inline int parse_client_hello(struct __sk_buff *skb, __u64 offset, __u32 skb_len, tls_info_t *tags) { +static __always_inline bool parse_client_hello(struct __sk_buff *skb, __u64 offset, tls_info_t *tags) { __u32 handshake_length; __u16 client_version; - int res; + __u32 skb_len = skb->len; // Parse the handshake header - res = parse_tls_handshake_header(skb, &offset, skb_len, &handshake_length, &client_version); - if (res != 0) - return res; + if (!parse_tls_handshake_header(skb, &offset, &handshake_length, &client_version)) { + return false; + } // Store client_version in tags (in case supported_versions extension is absent) set_tls_offered_version(tags, client_version); + // TLS 1.2 is the highest version we will see in the header. If the connection is actually a higher version (1.3), + // it must be extracted from the extensions. Lower versions (1.0, 1.1) will not have extensions. if (client_version != TLS_VERSION12) { - // If the version is less than 1.2, there won't be any extensions - return 0; + return true; } // Skip Random and Session ID - res = skip_random_and_session_id(skb, &offset, skb_len); - if (res != 0) - return res; + if (!skip_random_and_session_id(skb, &offset)) { + return false; + } // Read Cipher Suites Length (2 bytes) - if (offset + 2 > skb_len) - return -1; + if (offset + CIPHER_SUITES_LENGTH > skb_len) { + return false; + } __u16 cipher_suites_length; - if (bpf_skb_load_bytes(skb, offset, &cipher_suites_length, sizeof(cipher_suites_length)) < 0) - return -1; + if (bpf_skb_load_bytes(skb, offset, &cipher_suites_length, sizeof(cipher_suites_length)) < 0) { + return false; + } cipher_suites_length = bpf_ntohs(cipher_suites_length); - offset += 2; + offset += CIPHER_SUITES_LENGTH; // Skip Cipher Suites offset += cipher_suites_length; // Read Compression Methods Length (1 byte) - if (offset + 1 > skb_len) - return -1; + if (offset + COMPRESSION_METHODS_LENGTH > skb_len) { + return false; + } __u8 compression_methods_length; - if (bpf_skb_load_bytes(skb, offset, &compression_methods_length, sizeof(compression_methods_length)) < 0) - return -1; - offset += 1; + if (bpf_skb_load_bytes(skb, offset, &compression_methods_length, sizeof(compression_methods_length)) < 0) { + return false; + } + offset += COMPRESSION_METHODS_LENGTH; // Skip Compression Methods offset += compression_methods_length; // Check if extensions are present - if (offset + 2 > skb_len) - return -1; + if (offset + 2 > skb_len) { + return false; + } // Read Extensions Length (2 bytes) __u16 extensions_length; - if (bpf_skb_load_bytes(skb, offset, &extensions_length, sizeof(extensions_length)) < 0) - return -1; + if (bpf_skb_load_bytes(skb, offset, &extensions_length, sizeof(extensions_length)) < 0) { + return false; + } extensions_length = bpf_ntohs(extensions_length); offset += 2; // Ensure we don't read beyond the packet - if (offset + extensions_length > skb_len) - return -1; + if (offset + extensions_length > skb_len) { + return false; + } __u64 extensions_end = offset + extensions_length; // Parse Extensions - res = parse_tls_extensions(skb, &offset, extensions_end, skb_len, tags, true); - if (res != 0) - return res; - - return 0; + return parse_tls_extensions(skb, &offset, extensions_end, tags, true); } // parse_server_hello reads the ServerHello message from the TLS handshake and populates select tags -static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offset, __u32 skb_len, tls_info_t *tags) { +static __always_inline bool parse_server_hello(struct __sk_buff *skb, __u64 offset, tls_info_t *tags) { __u32 handshake_length; __u16 server_version; - int res; + __u32 skb_len = skb->len; // Parse the handshake header - res = parse_tls_handshake_header(skb, &offset, skb_len, &handshake_length, &server_version); - if (res != 0) - return res; + if (!parse_tls_handshake_header(skb, &offset, &handshake_length, &server_version)) { + return false; + } // Set the version here and try to get the "real" version from the extensions // Note: In TLS 1.3, the server_version field is set to 0x0303 (TLS 1.2) @@ -351,69 +391,75 @@ static __always_inline int parse_server_hello(struct __sk_buff *skb, __u64 offse tags->chosen_version = server_version; // Skip Random and Session ID - res = skip_random_and_session_id(skb, &offset, skb_len); - if (res != 0) - return res; + if (!skip_random_and_session_id(skb, &offset)) { + return false; + } // Read Cipher Suite (2 bytes) - if (offset + 2 > skb_len) - return -1; + if (offset + CIPHER_SUITES_LENGTH > skb_len) { + return false; + } __u16 cipher_suite; - if (bpf_skb_load_bytes(skb, offset, &cipher_suite, sizeof(cipher_suite)) < 0) - return -1; + if (bpf_skb_load_bytes(skb, offset, &cipher_suite, sizeof(cipher_suite)) < 0) { + return false; + } cipher_suite = bpf_ntohs(cipher_suite); - offset += 2; + offset += CIPHER_SUITES_LENGTH; // Skip Compression Method (1 byte) - offset += 1; + offset += COMPRESSION_METHODS_LENGTH; // Store parsed data into tags tags->cipher_suite = cipher_suite; + // TLS 1.2 is the highest version we will see in the header. If the connection is actually a higher version (1.3), + // it must be extracted from the extensions. Lower versions (1.0, 1.1) will not have extensions. if (tags->chosen_version != TLS_VERSION12) { - // If the version is less than 1.2, there won't be any extensions - return 0; + return true; } // Check if there are extensions - if (offset + 2 > skb_len) - return -1; + if (offset + 2 > skb_len) { + return false; + } // Read Extensions Length (2 bytes) __u16 extensions_length; - if (bpf_skb_load_bytes(skb, offset, &extensions_length, sizeof(extensions_length)) < 0) - return -1; + if (bpf_skb_load_bytes(skb, offset, &extensions_length, sizeof(extensions_length)) < 0) { + return false; + } extensions_length = bpf_ntohs(extensions_length); offset += 2; // Ensure we don't read beyond the packet __u64 handshake_end = offset + handshake_length; - if (offset + extensions_length > skb_len || offset + extensions_length > handshake_end) - return -1; + if (offset + extensions_length > skb_len || offset + extensions_length > handshake_end) { + return false; + } __u64 extensions_end = offset + extensions_length; // Parse Extensions - res = parse_tls_extensions(skb, &offset, extensions_end, skb_len, tags, false); - if (res != 0) - return res; - - return 0; + return parse_tls_extensions(skb, &offset, extensions_end, tags, false); } static __always_inline bool is_tls_handshake_type(struct __sk_buff *skb, tls_record_header_t *tls_hdr, __u64 offset, __u8 expected_handshake_type) { - if (!tls_hdr) + if (!tls_hdr) { return false; - if (tls_hdr->content_type != TLS_HANDSHAKE) + } + if (tls_hdr->content_type != TLS_HANDSHAKE) { return false; + } // Read handshake type - if (offset + 1 > skb->len) + if (offset + 1 > skb->len) { return false; + } __u8 handshake_type; - if (bpf_skb_load_bytes(skb, offset, &handshake_type, sizeof(handshake_type)) < 0) + if (bpf_skb_load_bytes(skb, offset, &handshake_type, sizeof(handshake_type)) < 0) { return false; + } return handshake_type == expected_handshake_type; } diff --git a/pkg/network/ebpf/c/tracer/tracer.h b/pkg/network/ebpf/c/tracer/tracer.h index 8ca45be89fed8f..2eb619c7337807 100644 --- a/pkg/network/ebpf/c/tracer/tracer.h +++ b/pkg/network/ebpf/c/tracer/tracer.h @@ -34,6 +34,11 @@ typedef struct { __u16 cipher_suite; __u8 offered_versions; } tls_info_t; +\ +typedef struct { + tls_info_t info; + __u64 updated; +} tls_info_wrapper_t; typedef struct { __u64 sent_bytes; diff --git a/pkg/network/ebpf/kprobe_types.go b/pkg/network/ebpf/kprobe_types.go index 663fdae87de547..ab5f569a9c3ecc 100644 --- a/pkg/network/ebpf/kprobe_types.go +++ b/pkg/network/ebpf/kprobe_types.go @@ -31,6 +31,7 @@ type BindSyscallArgs C.bind_syscall_args_t type ProtocolStack C.protocol_stack_t type ProtocolStackWrapper C.protocol_stack_wrapper_t type TLSTags C.tls_info_t +type TLSTagsWrapper C.tls_info_wrapper_t // udp_recv_sock_t have *sock and *msghdr struct members, we make them opaque here type _Ctype_struct_sock uint64 diff --git a/pkg/network/ebpf/kprobe_types_linux.go b/pkg/network/ebpf/kprobe_types_linux.go index 9339ac1ec0f281..5e413cb1964f19 100644 --- a/pkg/network/ebpf/kprobe_types_linux.go +++ b/pkg/network/ebpf/kprobe_types_linux.go @@ -109,6 +109,10 @@ type TLSTags struct { Offered_versions uint8 Pad_cgo_0 [1]byte } +type TLSTagsWrapper struct { + Info TLSTags + Updated uint64 +} type _Ctype_struct_sock uint64 type _Ctype_struct_msghdr uint64 diff --git a/pkg/network/protocols/tls/types.go b/pkg/network/protocols/tls/types.go index 680c6e8156124d..c05713e479dacb 100644 --- a/pkg/network/protocols/tls/types.go +++ b/pkg/network/protocols/tls/types.go @@ -79,10 +79,7 @@ func (t *Tags) String() string { // FormatTLSVersion converts a version uint16 to its string representation func FormatTLSVersion(version uint16) string { - if name, ok := tlsVersionNames[version]; ok { - return name - } - return "" + return tlsVersionNames[version] } // parseOfferedVersions parses the Offered_versions bitmask into a slice of version strings diff --git a/pkg/network/tracer/connection/ebpf_tracer.go b/pkg/network/tracer/connection/ebpf_tracer.go index 033f8088279939..be3776a7b4dc7a 100644 --- a/pkg/network/tracer/connection/ebpf_tracer.go +++ b/pkg/network/tracer/connection/ebpf_tracer.go @@ -47,6 +47,7 @@ const ( ) var tcpOngoingConnectMapTTL = 30 * time.Minute.Nanoseconds() +var tlsTagsMapTTL = 30 * time.Minute.Nanoseconds() var EbpfTracerTelemetry = struct { //nolint:revive // TODO connections telemetry.Gauge @@ -152,6 +153,8 @@ type ebpfTracer struct { // periodically clean the ongoing connection pid map ongoingConnectCleaner *ddebpf.MapCleaner[netebpf.SkpConn, netebpf.PidTs] + // periodically clean the enhanced TLS tags map + TLSTagsCleaner *ddebpf.MapCleaner[netebpf.ConnTuple, netebpf.TLSTagsWrapper] removeTuple *netebpf.ConnTuple @@ -274,7 +277,7 @@ func newEbpfTracer(config *config.Config, _ telemetryComponent.Component) (Trace ch: newCookieHasher(), } - tr.setupMapCleaner(m) + tr.setupMapCleaners(m) tr.conns, err = maps.GetMap[netebpf.ConnTuple, netebpf.ConnStats](m, probes.ConnMap) if err != nil { @@ -353,6 +356,7 @@ func (t *ebpfTracer) Stop() { _ = t.m.Stop(manager.CleanAll) t.closeConsumer.Stop() t.ongoingConnectCleaner.Stop() + t.TLSTagsCleaner.Stop() if t.closeTracer != nil { t.closeTracer() } @@ -696,8 +700,14 @@ func (t *ebpfTracer) getTCPStats(stats *netebpf.TCPStats, tuple *netebpf.ConnTup return t.tcpStats.Lookup(tuple, stats) == nil } -// setupMapCleaner sets up a map cleaner for the tcp_ongoing_connect_pid map -func (t *ebpfTracer) setupMapCleaner(m *manager.Manager) { +// setupMapCleaners sets up the map cleaners for the eBPF maps +func (t *ebpfTracer) setupMapCleaners(m *manager.Manager) { + t.setupOngoingConnectMapCleaner(m) + t.setupTLSTagsMapCleaner(m) +} + +// setupOngoingConnectMapCleaner sets up a map cleaner for the tcp_ongoing_connect_pid map +func (t *ebpfTracer) setupOngoingConnectMapCleaner(m *manager.Manager) { tcpOngoingConnectPidMap, _, err := m.GetMap(probes.TCPOngoingConnectPid) if err != nil { log.Errorf("error getting %v map: %s", probes.TCPOngoingConnectPid, err) @@ -721,6 +731,27 @@ func (t *ebpfTracer) setupMapCleaner(m *manager.Manager) { t.ongoingConnectCleaner = tcpOngoingConnectPidCleaner } +// setupTLSTagsMapCleaner sets up a map cleaner for the tls_enhanced_tags map +func (t *ebpfTracer) setupTLSTagsMapCleaner(m *manager.Manager) { + TLSTagsMap, _, err := m.GetMap(probes.EnhancedTLSTagsMap) + if err != nil { + log.Errorf("error getting %v map: %s", probes.EnhancedTLSTagsMap, err) + return + } + + TLSTagsMapCleaner, err := ddebpf.NewMapCleaner[netebpf.ConnTuple, netebpf.TLSTagsWrapper](TLSTagsMap, 1024, probes.EnhancedTLSTagsMap, "npm_tracer") + if err != nil { + log.Errorf("error creating map cleaner: %s", err) + return + } + TLSTagsMapCleaner.Clean(time.Minute*1, nil, nil, func(now int64, _ netebpf.ConnTuple, val netebpf.TLSTagsWrapper) bool { + ts := int64(val.Updated) + return ts > 0 && now-ts > tlsTagsMapTTL + }) + + t.TLSTagsCleaner = TLSTagsMapCleaner +} + func populateConnStats(stats *network.ConnectionStats, t *netebpf.ConnTuple, s *netebpf.ConnStats, ch *cookieHasher) { *stats = network.ConnectionStats{ConnectionTuple: network.ConnectionTuple{ Pid: t.Pid, diff --git a/pkg/network/usm/ebpf_main.go b/pkg/network/usm/ebpf_main.go index 861508bd2ae089..85f9f7123a593b 100644 --- a/pkg/network/usm/ebpf_main.go +++ b/pkg/network/usm/ebpf_main.go @@ -395,10 +395,6 @@ func (e *ebpfProgram) init(buf bytecode.AssetReader, options manager.Options) er MaxEntries: e.cfg.MaxTrackedConnections, EditorFlag: manager.EditMaxEntries, }, - probes.EnhancedTLSTagsMap: { - MaxEntries: e.cfg.MaxTrackedConnections, - EditorFlag: manager.EditMaxEntries, - }, tupleByPidFDMap: { MaxEntries: e.cfg.MaxTrackedConnections, EditorFlag: manager.EditMaxEntries, From 8ab745ffb6ebecdf679ca557a96c6fab31113033 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Mon, 9 Dec 2024 23:17:58 -0500 Subject: [PATCH 43/53] optimize memory allocation pattern for tls tags --- pkg/network/protocols/tls/types.go | 78 ++++++++++++++++--------- pkg/network/protocols/tls/types_test.go | 39 +++---------- pkg/network/tracer/tracer_linux_test.go | 9 +-- 3 files changed, 62 insertions(+), 64 deletions(-) diff --git a/pkg/network/protocols/tls/types.go b/pkg/network/protocols/tls/types.go index c05713e479dacb..7eed3dcf8fc905 100644 --- a/pkg/network/protocols/tls/types.go +++ b/pkg/network/protocols/tls/types.go @@ -11,6 +11,17 @@ import ( "fmt" ) +// Constants for tag keys +const ( + TagTLSVersion = "tls.version:" + TagTLSCipherSuiteID = "tls.cipher_suite_id:" + TagTLSClientVersion = "tls.client_version:" + Version10 = "tls_1.0" + Version11 = "tls_1.1" + Version12 = "tls_1.2" + Version13 = "tls_1.3" +) + // Bitmask constants for Offered_versions matching kernelspace definitions const ( OfferedTLSVersion10 uint8 = 0x01 @@ -19,12 +30,20 @@ const ( OfferedTLSVersion13 uint8 = 0x08 ) -// mapping of version constants to their string representations -var tlsVersionNames = map[uint16]string{ - tls.VersionTLS10: "tls_1.0", - tls.VersionTLS11: "tls_1.1", - tls.VersionTLS12: "tls_1.2", - tls.VersionTLS13: "tls_1.3", +// VersionTags maps TLS versions to tag names for server chosen version (exported for testing) +var VersionTags = map[uint16]string{ + tls.VersionTLS10: TagTLSVersion + Version10, + tls.VersionTLS11: TagTLSVersion + Version11, + tls.VersionTLS12: TagTLSVersion + Version12, + tls.VersionTLS13: TagTLSVersion + Version13, +} + +// ClientVersionTags maps TLS versions to tag names for client offered versions (exported for testing) +var ClientVersionTags = map[uint16]string{ + tls.VersionTLS10: TagTLSClientVersion + Version10, + tls.VersionTLS11: TagTLSClientVersion + Version11, + tls.VersionTLS12: TagTLSClientVersion + Version12, + tls.VersionTLS13: TagTLSClientVersion + Version13, } // Mapping of offered version bitmasks to version constants @@ -38,13 +57,6 @@ var offeredVersionBitmask = []struct { {OfferedTLSVersion13, tls.VersionTLS13}, } -// Constants for tag keys -const ( - TagTLSVersion = "tls.version:" - TagTLSCipherSuiteID = "tls.cipher_suite_id:" - TagTLSClientVersion = "tls.client_version:" -) - // Tags holds the TLS tags. It is used to store the TLS version, cipher suite and offered versions. // We can't use the struct from eBPF as the definition is shared with windows. type Tags struct { @@ -77,17 +89,12 @@ func (t *Tags) String() string { return fmt.Sprintf("ChosenVersion: %d, CipherSuite: %d, OfferedVersions: %d", t.ChosenVersion, t.CipherSuite, t.OfferedVersions) } -// FormatTLSVersion converts a version uint16 to its string representation -func FormatTLSVersion(version uint16) string { - return tlsVersionNames[version] -} - // parseOfferedVersions parses the Offered_versions bitmask into a slice of version strings func parseOfferedVersions(offeredVersions uint8) []string { - versions := []string{} + versions := make([]string, 0, 4) for _, ov := range offeredVersionBitmask { if (offeredVersions & ov.bitMask) != 0 { - if name := tlsVersionNames[ov.version]; name != "" { + if name := ClientVersionTags[ov.version]; name != "" { versions = append(versions, name) } } @@ -95,6 +102,21 @@ func parseOfferedVersions(offeredVersions uint8) []string { return versions } +func hexCipherSuiteTag(cipherSuite uint16) string { + // Preallocate a buffer for "0x" + 4 hex digits = 6 chars + var buf [6]byte + buf[0] = '0' + buf[1] = 'x' + hex := "0123456789ABCDEF" + + buf[2] = hex[(cipherSuite>>12)&0xF] + buf[3] = hex[(cipherSuite>>8)&0xF] + buf[4] = hex[(cipherSuite>>4)&0xF] + buf[5] = hex[cipherSuite&0xF] + + return TagTLSCipherSuiteID + string(buf[:]) +} + // GetTLSDynamicTags generates dynamic tags based on TLS information func GetTLSDynamicTags(tls *Tags) map[string]struct{} { tags := make(map[string]struct{}) @@ -103,18 +125,18 @@ func GetTLSDynamicTags(tls *Tags) map[string]struct{} { } // Server chosen version - if versionName := FormatTLSVersion(tls.ChosenVersion); versionName != "" { - tags[TagTLSVersion+versionName] = struct{}{} - } - - // Cipher suite ID as hex string - if tls.CipherSuite != 0 { - tags[TagTLSCipherSuiteID+fmt.Sprintf("0x%04X", tls.CipherSuite)] = struct{}{} + if tag, ok := VersionTags[tls.ChosenVersion]; ok { + tags[tag] = struct{}{} } // Client offered versions for _, versionName := range parseOfferedVersions(tls.OfferedVersions) { - tags[TagTLSClientVersion+versionName] = struct{}{} + tags[versionName] = struct{}{} + } + + // Cipher suite ID as hex string + if tls.CipherSuite != 0 { + tags[hexCipherSuiteTag(tls.CipherSuite)] = struct{}{} } return tags diff --git a/pkg/network/protocols/tls/types_test.go b/pkg/network/protocols/tls/types_test.go index b812bb4edb460d..07e3948bbdf717 100644 --- a/pkg/network/protocols/tls/types_test.go +++ b/pkg/network/protocols/tls/types_test.go @@ -12,44 +12,19 @@ import ( "testing" ) -func TestFormatTLSVersion(t *testing.T) { - tests := []struct { - version uint16 - expected string - }{ - {tls.VersionTLS10, "tls_1.0"}, - {tls.VersionTLS11, "tls_1.1"}, - {tls.VersionTLS12, "tls_1.2"}, - {tls.VersionTLS13, "tls_1.3"}, - {0xFFFF, ""}, // Unknown version - {0x0000, ""}, // Zero value - {0x0305, ""}, // Version just above known versions - {0x01FF, ""}, // Random unknown version - } - - for _, test := range tests { - t.Run(fmt.Sprintf("Version_0x%04X", test.version), func(t *testing.T) { - result := FormatTLSVersion(test.version) - if result != test.expected { - t.Errorf("FormatTLSVersion(0x%04X) = %q; want %q", test.version, result, test.expected) - } - }) - } -} - func TestParseOfferedVersions(t *testing.T) { tests := []struct { offeredVersions uint8 expected []string }{ {0x00, []string{}}, // No versions offered - {OfferedTLSVersion10, []string{"tls_1.0"}}, - {OfferedTLSVersion11, []string{"tls_1.1"}}, - {OfferedTLSVersion12, []string{"tls_1.2"}}, - {OfferedTLSVersion13, []string{"tls_1.3"}}, - {OfferedTLSVersion10 | OfferedTLSVersion12, []string{"tls_1.0", "tls_1.2"}}, - {OfferedTLSVersion11 | OfferedTLSVersion13, []string{"tls_1.1", "tls_1.3"}}, - {0xFF, []string{"tls_1.0", "tls_1.1", "tls_1.2", "tls_1.3"}}, // All bits set + {OfferedTLSVersion10, []string{"tls.client_version:tls_1.0"}}, + {OfferedTLSVersion11, []string{"tls.client_version:tls_1.1"}}, + {OfferedTLSVersion12, []string{"tls.client_version:tls_1.2"}}, + {OfferedTLSVersion13, []string{"tls.client_version:tls_1.3"}}, + {OfferedTLSVersion10 | OfferedTLSVersion12, []string{"tls.client_version:tls_1.0", "tls.client_version:tls_1.2"}}, + {OfferedTLSVersion11 | OfferedTLSVersion13, []string{"tls.client_version:tls_1.1", "tls.client_version:tls_1.3"}}, + {0xFF, []string{"tls.client_version:tls_1.0", "tls.client_version:tls_1.1", "tls.client_version:tls_1.2", "tls.client_version:tls_1.3"}}, // All bits set {0x40, []string{}}, // Undefined bit set {0x80, []string{}}, // Undefined bit set } diff --git a/pkg/network/tracer/tracer_linux_test.go b/pkg/network/tracer/tracer_linux_test.go index 8eb0cfc6dc6254..0de8704b2c35eb 100644 --- a/pkg/network/tracer/tracer_linux_test.go +++ b/pkg/network/tracer/tracer_linux_test.go @@ -2695,23 +2695,24 @@ func validateTLSTags(t *testing.T, tr *Tracer, port uint16, scenario uint16) boo } // Check that the negotiated version tag is present - negotiatedVersionTag := ddtls.TagTLSVersion + ddtls.FormatTLSVersion(scenario) + negotiatedVersionTag := ddtls.VersionTags[scenario] if _, ok := tlsTags[negotiatedVersionTag]; !ok { t.Logf("Negotiated version tag '%s' not found", negotiatedVersionTag) return false } // Check that the client offered version tag is present - clientVersionTag := ddtls.TagTLSClientVersion + ddtls.FormatTLSVersion(scenario) + clientVersionTag := ddtls.ClientVersionTags[scenario] if _, ok := tlsTags[clientVersionTag]; !ok { + t.Log(tlsTags) t.Logf("Client offered version tag '%s' not found", clientVersionTag) return false } if scenario == tls.VersionTLS13 { expectedClientVersions := []string{ - ddtls.TagTLSClientVersion + ddtls.FormatTLSVersion(tls.VersionTLS12), - ddtls.TagTLSClientVersion + ddtls.FormatTLSVersion(tls.VersionTLS13), + ddtls.ClientVersionTags[tls.VersionTLS12], + ddtls.ClientVersionTags[tls.VersionTLS13], } for _, tag := range expectedClientVersions { if _, ok := tlsTags[tag]; !ok { From 8f65b8af27801e1fc42c947ab750688d2384d652 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Wed, 11 Dec 2024 10:46:59 -0500 Subject: [PATCH 44/53] improve comments --- .../classification/protocol-classification.h | 1 + pkg/network/ebpf/c/protocols/tls/tls.h | 166 ++++++++---------- pkg/network/protocols/tls/types.go | 2 +- pkg/network/protocols/tls/types_test.go | 2 +- pkg/network/tracer/connection/ebpf_tracer.go | 5 +- .../tracer/connection/kprobe/tracer.go | 4 +- 6 files changed, 86 insertions(+), 94 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h index 38062c1fc7d637..b6e2397f5e9a62 100644 --- a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h +++ b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h @@ -63,6 +63,7 @@ // updates the the protocol stack and adds the current layer to the routing skip list static __always_inline void update_protocol_information(usm_context_t *usm_ctx, protocol_stack_t *stack, protocol_t proto) { set_protocol(stack, proto); + // Mark the current layer as known except for TLS, since there is still metadata to be extracted if (proto != PROTOCOL_TLS) { usm_ctx->routing_skip_layers |= proto; } diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index e44b4ace317ee8..24dff6151f8bdf 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -3,6 +3,7 @@ #include "tracer/tracer.h" +// TLS version constants (SSL versions are deprecated, included for completeness) #define SSL_VERSION20 0x0200 #define SSL_VERSION30 0x0300 #define TLS_VERSION10 0x0301 @@ -10,42 +11,54 @@ #define TLS_VERSION12 0x0303 #define TLS_VERSION13 0x0304 -// TLS Content Types as per RFC 5246 Section 6.2.1 -#define TLS_HANDSHAKE 0x16 -#define TLS_APPLICATION_DATA 0x17 +// TLS Content Types (RFC 5246 Section 6.2.1) +#define TLS_HANDSHAKE 0x16 +#define TLS_APPLICATION_DATA 0x17 +// TLS Handshake Types #define TLS_HANDSHAKE_CLIENT_HELLO 0x01 #define TLS_HANDSHAKE_SERVER_HELLO 0x02 +// Bitmask constants for offered versions #define TLS_VERSION10_BIT 0x01 #define TLS_VERSION11_BIT 0x02 #define TLS_VERSION12_BIT 0x04 #define TLS_VERSION13_BIT 0x08 -// TLS extensions to parse from the Hello message when searching for the SUPPORTED_VERSIONS_EXTENSION +// Maximum number of extensions to parse when looking for SUPPORTED_VERSIONS_EXTENSION #define MAX_EXTENSIONS 16 #define SUPPORTED_VERSIONS_EXTENSION 0x002B -// this corresponds to 16 KB, which is the maximum TLS record size as per the specification +// Maximum TLS record payload size (16 KB) #define TLS_MAX_PAYLOAD_LENGTH (1 << 14) -// Byte lengths of fields in the TLS handshake header -#define RANDOM_LENGTH 32 -#define TLS_HANDSHAKE_LENGTH 3 -#define PROTOCOL_VERSION_LENGTH 2 -#define SESSION_ID_LENGTH 1 -#define CIPHER_SUITES_LENGTH 2 -#define COMPRESSION_METHODS_LENGTH 1 -#define EXTENSION_TYPE_LENGTH 2 +// Field Lengths +#define TLS_HANDSHAKE_LENGTH 3 // Handshake length is 3 bytes +#define RANDOM_LENGTH 32 // Random field length in Client/Server Hello +#define PROTOCOL_VERSION_LENGTH 2 // Protocol version field is 2 bytes +#define SESSION_ID_LENGTH 1 // Session ID length field is 1 byte +#define CIPHER_SUITES_LENGTH 2 // Cipher Suites length field is 2 bytes +#define COMPRESSION_METHODS_LENGTH 1 // Compression Methods length field is 1 byte +#define EXTENSION_TYPE_LENGTH 2 // Extension Type field is 2 bytes +#define EXTENSION_LENGTH_FIELD 2 // Extension Length field is 2 bytes -// TLS record layer header structure (https://www.rfc-editor.org/rfc/rfc5246#page-19) +// For single-byte fields (list lengths, etc.) +#define SINGLE_BYTE_LENGTH 1 + +// Minimum extension header length = Extension Type (2 bytes) + Extension Length (2 bytes) = 4 bytes +#define MIN_EXTENSION_HEADER_LENGTH (EXTENSION_TYPE_LENGTH + EXTENSION_LENGTH_FIELD) + +// Maximum number of supported versions we unroll for +#define MAX_SUPPORTED_VERSIONS 4 + +// TLS record layer header structure (RFC 5246) typedef struct { __u8 content_type; __u16 version; __u16 length; } __attribute__((packed)) tls_record_header_t; -// is_valid_tls_version checks if the version is a valid TLS version +// Checks if the TLS version is valid static __always_inline bool is_valid_tls_version(__u16 version) { switch (version) { case SSL_VERSION20: @@ -112,7 +125,7 @@ static __always_inline bool read_tls_record_header(struct __sk_buff *skb, __u64 // is_tls checks if the packet is a TLS packet and reads the TLS record header static __always_inline bool is_tls(struct __sk_buff *skb, __u64 nh_off, tls_record_header_t *tls_hdr) { - // Use the helper function to read and validate the TLS record header + // Read and validate the TLS record header if (!read_tls_record_header(skb, nh_off, tls_hdr)) { return false; } @@ -121,9 +134,9 @@ static __always_inline bool is_tls(struct __sk_buff *skb, __u64 nh_off, tls_reco return tls_hdr->content_type == TLS_HANDSHAKE || tls_hdr->content_type == TLS_APPLICATION_DATA; } +// Parses the TLS handshake header to extract handshake_length and protocol_version static __always_inline bool parse_tls_handshake_header(struct __sk_buff *skb, __u64 *offset, __u32 *handshake_length, __u16 *protocol_version) { - // Move offset past handshake type (1 byte) - *offset += 1; + *offset += SINGLE_BYTE_LENGTH; // Move past handshake type (1 byte) __u32 skb_len = skb->len; // Read handshake length (3 bytes) @@ -149,7 +162,7 @@ static __always_inline bool parse_tls_handshake_header(struct __sk_buff *skb, __ return false; } __u16 version; - if (bpf_skb_load_bytes(skb, *offset, &version, sizeof(version)) < 0) { + if (bpf_skb_load_bytes(skb, *offset, &version, PROTOCOL_VERSION_LENGTH) < 0) { return false; } *protocol_version = bpf_ntohs(version); @@ -158,6 +171,7 @@ static __always_inline bool parse_tls_handshake_header(struct __sk_buff *skb, __ return true; } +// skip_random_and_session_id Skips the Random (32 bytes) and Session ID from the TLS hello messages by incrementing the offset pointer static __always_inline bool skip_random_and_session_id(struct __sk_buff *skb, __u64 *offset) { // Skip Random (32 bytes) *offset += RANDOM_LENGTH; @@ -168,7 +182,7 @@ static __always_inline bool skip_random_and_session_id(struct __sk_buff *skb, __ return false; } __u8 session_id_length; - if (bpf_skb_load_bytes(skb, *offset, &session_id_length, sizeof(session_id_length)) < 0) { + if (bpf_skb_load_bytes(skb, *offset, &session_id_length, SESSION_ID_LENGTH) < 0) { return false; } *offset += SESSION_ID_LENGTH; @@ -177,70 +191,63 @@ static __always_inline bool skip_random_and_session_id(struct __sk_buff *skb, __ *offset += session_id_length; // Ensure we don't read beyond the packet - if (*offset > skb_len) { - return false; - } - - return true; + return *offset <= skb_len; } +// parse_supported_versions_extension looks for the supported_versions extension in the ClientHello or ServerHello and populates tags static __always_inline bool parse_supported_versions_extension(struct __sk_buff *skb, __u64 *offset, __u64 extensions_end, tls_info_t *tags, bool is_client_hello) { __u32 skb_len = skb->len; if (is_client_hello) { - // Read list length (1 byte) - if (*offset + 1 > skb_len || *offset + 1 > extensions_end) { + // Read supported version list length (1 byte) + if (*offset + SINGLE_BYTE_LENGTH > skb_len || *offset + SINGLE_BYTE_LENGTH > extensions_end) { return false; } __u8 sv_list_length; - if (bpf_skb_load_bytes(skb, *offset, &sv_list_length, sizeof(sv_list_length)) < 0) { + if (bpf_skb_load_bytes(skb, *offset, &sv_list_length, SINGLE_BYTE_LENGTH) < 0) { return false; } - *offset += 1; + *offset += SINGLE_BYTE_LENGTH; - // Ensure we don't read beyond the packet if (*offset + sv_list_length > skb_len || *offset + sv_list_length > extensions_end) { return false; } - // Parse the list of supported versions + // Parse the list of supported versions (2 bytes each) __u8 sv_offset = 0; __u16 sv_version; - #define MAX_SUPPORTED_VERSIONS 4 #pragma unroll(MAX_SUPPORTED_VERSIONS) for (int idx = 0; idx < MAX_SUPPORTED_VERSIONS; idx++) { if (sv_offset + 1 >= sv_list_length) { break; } - if (*offset + 2 > skb_len) { + // Each supported version is 2 bytes + if (*offset + PROTOCOL_VERSION_LENGTH > skb_len) { return false; } - // Load the supported version - if (bpf_skb_load_bytes(skb, *offset, &sv_version, sizeof(sv_version)) < 0) { + if (bpf_skb_load_bytes(skb, *offset, &sv_version, PROTOCOL_VERSION_LENGTH) < 0) { return false; } sv_version = bpf_ntohs(sv_version); - *offset += 2; + *offset += PROTOCOL_VERSION_LENGTH; - // Store the version set_tls_offered_version(tags, sv_version); - - sv_offset += 2; + sv_offset += PROTOCOL_VERSION_LENGTH; } } else { // ServerHello - // Extension Length should be 2 - if (*offset + 2 > skb_len) { + // The selected_version field is 2 bytes + if (*offset + PROTOCOL_VERSION_LENGTH > skb_len) { return false; } // Read selected version (2 bytes) __u16 selected_version; - if (bpf_skb_load_bytes(skb, *offset, &selected_version, sizeof(selected_version)) < 0) { + if (bpf_skb_load_bytes(skb, *offset, &selected_version, PROTOCOL_VERSION_LENGTH) < 0) { return false; } selected_version = bpf_ntohs(selected_version); - *offset += 2; + *offset += PROTOCOL_VERSION_LENGTH; tags->chosen_version = selected_version; } @@ -248,36 +255,36 @@ static __always_inline bool parse_supported_versions_extension(struct __sk_buff return true; } - +// parse_tls_extensions parses TLS extensions and looks for the supported_versions extension static __always_inline bool parse_tls_extensions(struct __sk_buff *skb, __u64 *offset, __u64 extensions_end, tls_info_t *tags, bool is_client_hello) { + __u32 skb_len = skb->len; __u16 extension_type; __u16 extension_length; #pragma unroll(MAX_EXTENSIONS) for (int i = 0; i < MAX_EXTENSIONS; i++) { - if (*offset + 4 > extensions_end) { + if (*offset + MIN_EXTENSION_HEADER_LENGTH > extensions_end) { break; } + // Read Extension Type (2 bytes) - if (bpf_skb_load_bytes(skb, *offset, &extension_type, sizeof(extension_type)) < 0) { + if (bpf_skb_load_bytes(skb, *offset, &extension_type, EXTENSION_TYPE_LENGTH) < 0) { return false; } extension_type = bpf_ntohs(extension_type); - *offset += 2; + *offset += EXTENSION_TYPE_LENGTH; // Read Extension Length (2 bytes) - if (bpf_skb_load_bytes(skb, *offset, &extension_length, sizeof(extension_length)) < 0) { + if (bpf_skb_load_bytes(skb, *offset, &extension_length, EXTENSION_LENGTH_FIELD) < 0) { return false; } extension_length = bpf_ntohs(extension_length); - *offset += 2; + *offset += EXTENSION_LENGTH_FIELD; - // Ensure we don't read beyond the packet - if (*offset + extension_length > skb->len || *offset + extension_length > extensions_end) { + if (*offset + extension_length > skb_len || *offset + extension_length > extensions_end) { return false; } - // Check for supported_versions extension if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { if (!parse_supported_versions_extension(skb, offset, extensions_end, tags, is_client_hello)) { return false; @@ -287,7 +294,6 @@ static __always_inline bool parse_tls_extensions(struct __sk_buff *skb, __u64 *o *offset += extension_length; } - // Ensure we don't run past the extensions_end if (*offset >= extensions_end) { break; } @@ -296,23 +302,19 @@ static __always_inline bool parse_tls_extensions(struct __sk_buff *skb, __u64 *o return true; } - -// parse_client_hello reads the ClientHello message from the TLS handshake and populates select tags +// parse_client_hello parses the ClientHello message and populates tags static __always_inline bool parse_client_hello(struct __sk_buff *skb, __u64 offset, tls_info_t *tags) { __u32 handshake_length; __u16 client_version; __u32 skb_len = skb->len; - // Parse the handshake header if (!parse_tls_handshake_header(skb, &offset, &handshake_length, &client_version)) { return false; } - // Store client_version in tags (in case supported_versions extension is absent) set_tls_offered_version(tags, client_version); - // TLS 1.2 is the highest version we will see in the header. If the connection is actually a higher version (1.3), - // it must be extracted from the extensions. Lower versions (1.0, 1.1) will not have extensions. + // If the version is less than TLS 1.2, no extensions to parse if (client_version != TLS_VERSION12) { return true; } @@ -327,7 +329,7 @@ static __always_inline bool parse_client_hello(struct __sk_buff *skb, __u64 offs return false; } __u16 cipher_suites_length; - if (bpf_skb_load_bytes(skb, offset, &cipher_suites_length, sizeof(cipher_suites_length)) < 0) { + if (bpf_skb_load_bytes(skb, offset, &cipher_suites_length, CIPHER_SUITES_LENGTH) < 0) { return false; } cipher_suites_length = bpf_ntohs(cipher_suites_length); @@ -341,7 +343,7 @@ static __always_inline bool parse_client_hello(struct __sk_buff *skb, __u64 offs return false; } __u8 compression_methods_length; - if (bpf_skb_load_bytes(skb, offset, &compression_methods_length, sizeof(compression_methods_length)) < 0) { + if (bpf_skb_load_bytes(skb, offset, &compression_methods_length, COMPRESSION_METHODS_LENGTH) < 0) { return false; } offset += COMPRESSION_METHODS_LENGTH; @@ -350,47 +352,39 @@ static __always_inline bool parse_client_hello(struct __sk_buff *skb, __u64 offs offset += compression_methods_length; // Check if extensions are present - if (offset + 2 > skb_len) { + if (offset + EXTENSION_LENGTH_FIELD > skb_len) { return false; } // Read Extensions Length (2 bytes) __u16 extensions_length; - if (bpf_skb_load_bytes(skb, offset, &extensions_length, sizeof(extensions_length)) < 0) { + if (bpf_skb_load_bytes(skb, offset, &extensions_length, EXTENSION_LENGTH_FIELD) < 0) { return false; } extensions_length = bpf_ntohs(extensions_length); - offset += 2; + offset += EXTENSION_LENGTH_FIELD; - // Ensure we don't read beyond the packet if (offset + extensions_length > skb_len) { return false; } __u64 extensions_end = offset + extensions_length; - // Parse Extensions return parse_tls_extensions(skb, &offset, extensions_end, tags, true); } - -// parse_server_hello reads the ServerHello message from the TLS handshake and populates select tags +// parse_server_hello parses the ServerHello message and populates tags static __always_inline bool parse_server_hello(struct __sk_buff *skb, __u64 offset, tls_info_t *tags) { __u32 handshake_length; __u16 server_version; __u32 skb_len = skb->len; - // Parse the handshake header if (!parse_tls_handshake_header(skb, &offset, &handshake_length, &server_version)) { return false; } - // Set the version here and try to get the "real" version from the extensions - // Note: In TLS 1.3, the server_version field is set to 0x0303 (TLS 1.2) - // The actual version is embedded in the supported_versions extension tags->chosen_version = server_version; - // Skip Random and Session ID if (!skip_random_and_session_id(skb, &offset)) { return false; } @@ -400,7 +394,7 @@ static __always_inline bool parse_server_hello(struct __sk_buff *skb, __u64 offs return false; } __u16 cipher_suite; - if (bpf_skb_load_bytes(skb, offset, &cipher_suite, sizeof(cipher_suite)) < 0) { + if (bpf_skb_load_bytes(skb, offset, &cipher_suite, CIPHER_SUITES_LENGTH) < 0) { return false; } cipher_suite = bpf_ntohs(cipher_suite); @@ -409,29 +403,24 @@ static __always_inline bool parse_server_hello(struct __sk_buff *skb, __u64 offs // Skip Compression Method (1 byte) offset += COMPRESSION_METHODS_LENGTH; - // Store parsed data into tags tags->cipher_suite = cipher_suite; - // TLS 1.2 is the highest version we will see in the header. If the connection is actually a higher version (1.3), - // it must be extracted from the extensions. Lower versions (1.0, 1.1) will not have extensions. if (tags->chosen_version != TLS_VERSION12) { return true; } - // Check if there are extensions - if (offset + 2 > skb_len) { + if (offset + EXTENSION_LENGTH_FIELD > skb_len) { return false; } // Read Extensions Length (2 bytes) __u16 extensions_length; - if (bpf_skb_load_bytes(skb, offset, &extensions_length, sizeof(extensions_length)) < 0) { + if (bpf_skb_load_bytes(skb, offset, &extensions_length, EXTENSION_LENGTH_FIELD) < 0) { return false; } extensions_length = bpf_ntohs(extensions_length); - offset += 2; + offset += EXTENSION_LENGTH_FIELD; - // Ensure we don't read beyond the packet __u64 handshake_end = offset + handshake_length; if (offset + extensions_length > skb_len || offset + extensions_length > handshake_end) { return false; @@ -439,11 +428,10 @@ static __always_inline bool parse_server_hello(struct __sk_buff *skb, __u64 offs __u64 extensions_end = offset + extensions_length; - // Parse Extensions return parse_tls_extensions(skb, &offset, extensions_end, tags, false); } - +// is_tls_handshake_type checks if the handshake type is the expected type (client or server hello) static __always_inline bool is_tls_handshake_type(struct __sk_buff *skb, tls_record_header_t *tls_hdr, __u64 offset, __u8 expected_handshake_type) { if (!tls_hdr) { return false; @@ -452,22 +440,24 @@ static __always_inline bool is_tls_handshake_type(struct __sk_buff *skb, tls_rec return false; } - // Read handshake type - if (offset + 1 > skb->len) { + // The handshake type is a single byte enumerated value + if (offset + SINGLE_BYTE_LENGTH > skb->len) { return false; - } + } __u8 handshake_type; - if (bpf_skb_load_bytes(skb, offset, &handshake_type, sizeof(handshake_type)) < 0) { + if (bpf_skb_load_bytes(skb, offset, &handshake_type, SINGLE_BYTE_LENGTH) < 0) { return false; } return handshake_type == expected_handshake_type; } +// is_tls_handshake_client_hello checks if the packet is a TLS ClientHello message static __always_inline bool is_tls_handshake_client_hello(struct __sk_buff *skb, tls_record_header_t *tls_hdr, __u64 offset) { return is_tls_handshake_type(skb, tls_hdr, offset, TLS_HANDSHAKE_CLIENT_HELLO); } +// is_tls_handshake_server_hello checks if the packet is a TLS ServerHello message static __always_inline bool is_tls_handshake_server_hello(struct __sk_buff *skb, tls_record_header_t *tls_hdr, __u64 offset) { return is_tls_handshake_type(skb, tls_hdr, offset, TLS_HANDSHAKE_SERVER_HELLO); } diff --git a/pkg/network/protocols/tls/types.go b/pkg/network/protocols/tls/types.go index 7eed3dcf8fc905..3d1974f434dc97 100644 --- a/pkg/network/protocols/tls/types.go +++ b/pkg/network/protocols/tls/types.go @@ -121,7 +121,7 @@ func hexCipherSuiteTag(cipherSuite uint16) string { func GetTLSDynamicTags(tls *Tags) map[string]struct{} { tags := make(map[string]struct{}) if tls == nil { - return tags + return nil } // Server chosen version diff --git a/pkg/network/protocols/tls/types_test.go b/pkg/network/protocols/tls/types_test.go index 07e3948bbdf717..359ff509d6b263 100644 --- a/pkg/network/protocols/tls/types_test.go +++ b/pkg/network/protocols/tls/types_test.go @@ -48,7 +48,7 @@ func TestGetTLSDynamicTags(t *testing.T) { { name: "Nil_TLSTags", tlsTags: nil, - expected: map[string]struct{}{}, + expected: nil, }, { name: "All_Fields_Populated", diff --git a/pkg/network/tracer/connection/ebpf_tracer.go b/pkg/network/tracer/connection/ebpf_tracer.go index be3776a7b4dc7a..9afb9614f90a7b 100644 --- a/pkg/network/tracer/connection/ebpf_tracer.go +++ b/pkg/network/tracer/connection/ebpf_tracer.go @@ -47,7 +47,7 @@ const ( ) var tcpOngoingConnectMapTTL = 30 * time.Minute.Nanoseconds() -var tlsTagsMapTTL = 30 * time.Minute.Nanoseconds() +var tlsTagsMapTTL = 3 * time.Minute.Nanoseconds() var EbpfTracerTelemetry = struct { //nolint:revive // TODO connections telemetry.Gauge @@ -744,7 +744,8 @@ func (t *ebpfTracer) setupTLSTagsMapCleaner(m *manager.Manager) { log.Errorf("error creating map cleaner: %s", err) return } - TLSTagsMapCleaner.Clean(time.Minute*1, nil, nil, func(now int64, _ netebpf.ConnTuple, val netebpf.TLSTagsWrapper) bool { + // slight jitter to avoid all maps being cleaned at the same time + TLSTagsMapCleaner.Clean(time.Second*70, nil, nil, func(now int64, _ netebpf.ConnTuple, val netebpf.TLSTagsWrapper) bool { ts := int64(val.Updated) return ts > 0 && now-ts > tlsTagsMapTTL }) diff --git a/pkg/network/tracer/connection/kprobe/tracer.go b/pkg/network/tracer/connection/kprobe/tracer.go index 56f282b5e20346..2dbfea6629c1d8 100644 --- a/pkg/network/tracer/connection/kprobe/tracer.go +++ b/pkg/network/tracer/connection/kprobe/tracer.go @@ -106,8 +106,8 @@ var ( ) // ClassificationSupported returns true if the current kernel version supports the classification feature. -// The kernel has to be newer than 4.11.0 since we are using bpf_skb_load_bytes (4.5.0+) method to read from the socket -// filter which was added in 4.11, and a tracepoint (4.7.0+) +// The kernel has to be newer than 4.11.0 since we are using bpf_skb_load_bytes (4.5.0+) method which was added to +// socket filters in 4.11.0, and a tracepoint (4.7.0+) func ClassificationSupported(config *config.Config) bool { if !config.ProtocolClassificationEnabled { return false From 58b6912a9ebee5406d255491952f382275cf55c7 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Thu, 12 Dec 2024 14:47:04 -0500 Subject: [PATCH 45/53] use skb_info.data_end instead of skb_len, clean up tests, more comments --- .../classification/protocol-classification.h | 12 ++- pkg/network/ebpf/c/protocols/tls/tls.h | 101 +++++++++--------- pkg/network/tracer/testutil/tcp.go | 13 --- pkg/network/tracer/tracer_linux_test.go | 57 ++++++---- .../usm/tests/tracer_usm_linux_test.go | 16 ++- 5 files changed, 108 insertions(+), 91 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h index b6e2397f5e9a62..960f01a5375872 100644 --- a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h +++ b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h @@ -165,7 +165,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct tls_record_header_t tls_hdr = {0}; - if ((app_layer_proto == PROTOCOL_UNKNOWN || app_layer_proto == PROTOCOL_POSTGRES) && is_tls(skb, skb_info.data_off, &tls_hdr)) { + if ((app_layer_proto == PROTOCOL_UNKNOWN || app_layer_proto == PROTOCOL_POSTGRES) && is_tls(skb, skb_info.data_off, skb_info.data_end, &tls_hdr)) { protocol_stack = get_or_create_protocol_stack(&usm_ctx->tuple); if (!protocol_stack) { return; @@ -226,10 +226,11 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint_tls_ha goto next_program; } __u64 offset = usm_ctx->skb_info.data_off + sizeof(tls_record_header_t); - if (!is_tls_handshake_client_hello(skb, &usm_ctx->tls_header, offset)) { + __u32 data_end = usm_ctx->skb_info.data_end; + if (!is_tls_handshake_client_hello(skb, &usm_ctx->tls_header, offset, data_end)) { goto next_program; } - if (!parse_client_hello(skb, offset, tls_info)) { + if (!parse_client_hello(skb, offset, data_end, tls_info)) { return; } @@ -247,10 +248,11 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint_tls_ha goto next_program; } __u64 offset = usm_ctx->skb_info.data_off + sizeof(tls_record_header_t); - if (!is_tls_handshake_server_hello(skb, &usm_ctx->tls_header, offset)) { + __u32 data_end = usm_ctx->skb_info.data_end; + if (!is_tls_handshake_server_hello(skb, &usm_ctx->tls_header, offset, data_end)) { goto next_program; } - if (!parse_server_hello(skb, offset, tls_info)) { + if (!parse_server_hello(skb, offset, data_end, tls_info)) { return; } diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index 24dff6151f8bdf..d52479fa6951b3 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -94,16 +94,14 @@ static __always_inline void set_tls_offered_version(tls_info_t *tls_info, __u16 } // read_tls_record_header reads the TLS record header from the packet -static __always_inline bool read_tls_record_header(struct __sk_buff *skb, __u64 nh_off, tls_record_header_t *tls_hdr) { - __u32 skb_len = skb->len; - +static __always_inline bool read_tls_record_header(struct __sk_buff *skb, __u64 header_offset, __u32 data_end, tls_record_header_t *tls_hdr) { // Ensure there's enough space for TLS record header - if (nh_off + sizeof(tls_record_header_t) > skb_len) { + if (header_offset + sizeof(tls_record_header_t) > data_end) { return false; } // Read TLS record header - if (bpf_skb_load_bytes(skb, nh_off, tls_hdr, sizeof(tls_record_header_t)) < 0) { + if (bpf_skb_load_bytes(skb, header_offset, tls_hdr, sizeof(tls_record_header_t)) < 0) { return false; } @@ -120,13 +118,13 @@ static __always_inline bool read_tls_record_header(struct __sk_buff *skb, __u64 } // Ensure we don't read beyond the packet - return nh_off + sizeof(tls_record_header_t) + tls_hdr->length <= skb_len; + return header_offset + sizeof(tls_record_header_t) + tls_hdr->length <= data_end; } // is_tls checks if the packet is a TLS packet and reads the TLS record header -static __always_inline bool is_tls(struct __sk_buff *skb, __u64 nh_off, tls_record_header_t *tls_hdr) { +static __always_inline bool is_tls(struct __sk_buff *skb, __u64 header_offset, __u32 data_end, tls_record_header_t *tls_hdr) { // Read and validate the TLS record header - if (!read_tls_record_header(skb, nh_off, tls_hdr)) { + if (!read_tls_record_header(skb, header_offset, data_end, tls_hdr)) { return false; } @@ -135,12 +133,11 @@ static __always_inline bool is_tls(struct __sk_buff *skb, __u64 nh_off, tls_reco } // Parses the TLS handshake header to extract handshake_length and protocol_version -static __always_inline bool parse_tls_handshake_header(struct __sk_buff *skb, __u64 *offset, __u32 *handshake_length, __u16 *protocol_version) { +static __always_inline bool parse_tls_handshake_header(struct __sk_buff *skb, __u64 *offset, __u32 data_end, __u32 *handshake_length, __u16 *protocol_version) { *offset += SINGLE_BYTE_LENGTH; // Move past handshake type (1 byte) - __u32 skb_len = skb->len; // Read handshake length (3 bytes) - if (*offset + TLS_HANDSHAKE_LENGTH > skb_len) { + if (*offset + TLS_HANDSHAKE_LENGTH > data_end) { return false; } __u8 handshake_length_bytes[TLS_HANDSHAKE_LENGTH]; @@ -153,12 +150,12 @@ static __always_inline bool parse_tls_handshake_header(struct __sk_buff *skb, __ *offset += TLS_HANDSHAKE_LENGTH; // Ensure we don't read beyond the packet - if (*offset + *handshake_length > skb_len) { + if (*offset + *handshake_length > data_end) { return false; } // Read protocol version (2 bytes) - if (*offset + PROTOCOL_VERSION_LENGTH > skb_len) { + if (*offset + PROTOCOL_VERSION_LENGTH > data_end) { return false; } __u16 version; @@ -172,13 +169,12 @@ static __always_inline bool parse_tls_handshake_header(struct __sk_buff *skb, __ } // skip_random_and_session_id Skips the Random (32 bytes) and Session ID from the TLS hello messages by incrementing the offset pointer -static __always_inline bool skip_random_and_session_id(struct __sk_buff *skb, __u64 *offset) { +static __always_inline bool skip_random_and_session_id(struct __sk_buff *skb, __u64 *offset, __u32 data_end) { // Skip Random (32 bytes) *offset += RANDOM_LENGTH; - __u32 skb_len = skb->len; // Read Session ID Length (1 byte) - if (*offset + SESSION_ID_LENGTH > skb_len) { + if (*offset + SESSION_ID_LENGTH > data_end) { return false; } __u8 session_id_length; @@ -191,15 +187,14 @@ static __always_inline bool skip_random_and_session_id(struct __sk_buff *skb, __ *offset += session_id_length; // Ensure we don't read beyond the packet - return *offset <= skb_len; + return *offset <= data_end; } // parse_supported_versions_extension looks for the supported_versions extension in the ClientHello or ServerHello and populates tags -static __always_inline bool parse_supported_versions_extension(struct __sk_buff *skb, __u64 *offset, __u64 extensions_end, tls_info_t *tags, bool is_client_hello) { - __u32 skb_len = skb->len; +static __always_inline bool parse_supported_versions_extension(struct __sk_buff *skb, __u64 *offset, __u32 data_end, __u64 extensions_end, tls_info_t *tags, bool is_client_hello) { if (is_client_hello) { // Read supported version list length (1 byte) - if (*offset + SINGLE_BYTE_LENGTH > skb_len || *offset + SINGLE_BYTE_LENGTH > extensions_end) { + if (*offset + SINGLE_BYTE_LENGTH > data_end || *offset + SINGLE_BYTE_LENGTH > extensions_end) { return false; } __u8 sv_list_length; @@ -208,7 +203,7 @@ static __always_inline bool parse_supported_versions_extension(struct __sk_buff } *offset += SINGLE_BYTE_LENGTH; - if (*offset + sv_list_length > skb_len || *offset + sv_list_length > extensions_end) { + if (*offset + sv_list_length > data_end || *offset + sv_list_length > extensions_end) { return false; } @@ -221,7 +216,7 @@ static __always_inline bool parse_supported_versions_extension(struct __sk_buff break; } // Each supported version is 2 bytes - if (*offset + PROTOCOL_VERSION_LENGTH > skb_len) { + if (*offset + PROTOCOL_VERSION_LENGTH > data_end) { return false; } @@ -237,7 +232,7 @@ static __always_inline bool parse_supported_versions_extension(struct __sk_buff } else { // ServerHello // The selected_version field is 2 bytes - if (*offset + PROTOCOL_VERSION_LENGTH > skb_len) { + if (*offset + PROTOCOL_VERSION_LENGTH > data_end) { return false; } @@ -256,8 +251,7 @@ static __always_inline bool parse_supported_versions_extension(struct __sk_buff } // parse_tls_extensions parses TLS extensions and looks for the supported_versions extension -static __always_inline bool parse_tls_extensions(struct __sk_buff *skb, __u64 *offset, __u64 extensions_end, tls_info_t *tags, bool is_client_hello) { - __u32 skb_len = skb->len; +static __always_inline bool parse_tls_extensions(struct __sk_buff *skb, __u64 *offset, __u32 data_end, __u64 extensions_end, tls_info_t *tags, bool is_client_hello) { __u16 extension_type; __u16 extension_length; @@ -281,12 +275,12 @@ static __always_inline bool parse_tls_extensions(struct __sk_buff *skb, __u64 *o extension_length = bpf_ntohs(extension_length); *offset += EXTENSION_LENGTH_FIELD; - if (*offset + extension_length > skb_len || *offset + extension_length > extensions_end) { + if (*offset + extension_length > data_end || *offset + extension_length > extensions_end) { return false; } if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { - if (!parse_supported_versions_extension(skb, offset, extensions_end, tags, is_client_hello)) { + if (!parse_supported_versions_extension(skb, offset, data_end, extensions_end, tags, is_client_hello)) { return false; } } else { @@ -303,29 +297,28 @@ static __always_inline bool parse_tls_extensions(struct __sk_buff *skb, __u64 *o } // parse_client_hello parses the ClientHello message and populates tags -static __always_inline bool parse_client_hello(struct __sk_buff *skb, __u64 offset, tls_info_t *tags) { +static __always_inline bool parse_client_hello(struct __sk_buff *skb, __u64 offset, __u32 data_end, tls_info_t *tags) { __u32 handshake_length; __u16 client_version; - __u32 skb_len = skb->len; - if (!parse_tls_handshake_header(skb, &offset, &handshake_length, &client_version)) { + if (!parse_tls_handshake_header(skb, &offset, data_end, &handshake_length, &client_version)) { return false; } set_tls_offered_version(tags, client_version); - // If the version is less than TLS 1.2, no extensions to parse + // TLS 1.2 is the highest version we will see in the header. If the connection is actually a higher version (1.3), + // it must be extracted from the extensions. Lower versions (1.0, 1.1) will not have extensions. if (client_version != TLS_VERSION12) { return true; } - // Skip Random and Session ID - if (!skip_random_and_session_id(skb, &offset)) { + if (!skip_random_and_session_id(skb, &offset, data_end)) { return false; } // Read Cipher Suites Length (2 bytes) - if (offset + CIPHER_SUITES_LENGTH > skb_len) { + if (offset + CIPHER_SUITES_LENGTH > data_end) { return false; } __u16 cipher_suites_length; @@ -339,7 +332,7 @@ static __always_inline bool parse_client_hello(struct __sk_buff *skb, __u64 offs offset += cipher_suites_length; // Read Compression Methods Length (1 byte) - if (offset + COMPRESSION_METHODS_LENGTH > skb_len) { + if (offset + COMPRESSION_METHODS_LENGTH > data_end) { return false; } __u8 compression_methods_length; @@ -352,7 +345,7 @@ static __always_inline bool parse_client_hello(struct __sk_buff *skb, __u64 offs offset += compression_methods_length; // Check if extensions are present - if (offset + EXTENSION_LENGTH_FIELD > skb_len) { + if (offset + EXTENSION_LENGTH_FIELD > data_end) { return false; } @@ -364,33 +357,35 @@ static __always_inline bool parse_client_hello(struct __sk_buff *skb, __u64 offs extensions_length = bpf_ntohs(extensions_length); offset += EXTENSION_LENGTH_FIELD; - if (offset + extensions_length > skb_len) { + if (offset + extensions_length > data_end) { return false; } __u64 extensions_end = offset + extensions_length; - return parse_tls_extensions(skb, &offset, extensions_end, tags, true); + return parse_tls_extensions(skb, &offset, data_end, extensions_end, tags, true); } // parse_server_hello parses the ServerHello message and populates tags -static __always_inline bool parse_server_hello(struct __sk_buff *skb, __u64 offset, tls_info_t *tags) { +static __always_inline bool parse_server_hello(struct __sk_buff *skb, __u64 offset, __u32 data_end, tls_info_t *tags) { __u32 handshake_length; __u16 server_version; - __u32 skb_len = skb->len; - if (!parse_tls_handshake_header(skb, &offset, &handshake_length, &server_version)) { + if (!parse_tls_handshake_header(skb, &offset, data_end, &handshake_length, &server_version)) { return false; } + // Set the version here and try to get the "real" version from the extensions if possible + // Note: In TLS 1.3, the server_version field is set to 1.2 + // The actual version is embedded in the supported_versions extension tags->chosen_version = server_version; - if (!skip_random_and_session_id(skb, &offset)) { + if (!skip_random_and_session_id(skb, &offset, data_end)) { return false; } // Read Cipher Suite (2 bytes) - if (offset + CIPHER_SUITES_LENGTH > skb_len) { + if (offset + CIPHER_SUITES_LENGTH > data_end) { return false; } __u16 cipher_suite; @@ -405,11 +400,13 @@ static __always_inline bool parse_server_hello(struct __sk_buff *skb, __u64 offs tags->cipher_suite = cipher_suite; + // TLS 1.2 is the highest version we will see in the header. If the connection is actually a higher version (1.3), + // it must be extracted from the extensions. Lower versions (1.0, 1.1) will not have extensions. if (tags->chosen_version != TLS_VERSION12) { return true; } - if (offset + EXTENSION_LENGTH_FIELD > skb_len) { + if (offset + EXTENSION_LENGTH_FIELD > data_end) { return false; } @@ -422,17 +419,17 @@ static __always_inline bool parse_server_hello(struct __sk_buff *skb, __u64 offs offset += EXTENSION_LENGTH_FIELD; __u64 handshake_end = offset + handshake_length; - if (offset + extensions_length > skb_len || offset + extensions_length > handshake_end) { + if (offset + extensions_length > data_end || offset + extensions_length > handshake_end) { return false; } __u64 extensions_end = offset + extensions_length; - return parse_tls_extensions(skb, &offset, extensions_end, tags, false); + return parse_tls_extensions(skb, &offset, data_end, extensions_end, tags, false); } // is_tls_handshake_type checks if the handshake type is the expected type (client or server hello) -static __always_inline bool is_tls_handshake_type(struct __sk_buff *skb, tls_record_header_t *tls_hdr, __u64 offset, __u8 expected_handshake_type) { +static __always_inline bool is_tls_handshake_type(struct __sk_buff *skb, tls_record_header_t *tls_hdr, __u64 offset, __u32 data_end, __u8 expected_handshake_type) { if (!tls_hdr) { return false; } @@ -441,7 +438,7 @@ static __always_inline bool is_tls_handshake_type(struct __sk_buff *skb, tls_rec } // The handshake type is a single byte enumerated value - if (offset + SINGLE_BYTE_LENGTH > skb->len) { + if (offset + SINGLE_BYTE_LENGTH > data_end) { return false; } __u8 handshake_type; @@ -453,13 +450,13 @@ static __always_inline bool is_tls_handshake_type(struct __sk_buff *skb, tls_rec } // is_tls_handshake_client_hello checks if the packet is a TLS ClientHello message -static __always_inline bool is_tls_handshake_client_hello(struct __sk_buff *skb, tls_record_header_t *tls_hdr, __u64 offset) { - return is_tls_handshake_type(skb, tls_hdr, offset, TLS_HANDSHAKE_CLIENT_HELLO); +static __always_inline bool is_tls_handshake_client_hello(struct __sk_buff *skb, tls_record_header_t *tls_hdr, __u64 offset, __u32 data_end) { + return is_tls_handshake_type(skb, tls_hdr, offset, data_end, TLS_HANDSHAKE_CLIENT_HELLO); } // is_tls_handshake_server_hello checks if the packet is a TLS ServerHello message -static __always_inline bool is_tls_handshake_server_hello(struct __sk_buff *skb, tls_record_header_t *tls_hdr, __u64 offset) { - return is_tls_handshake_type(skb, tls_hdr, offset, TLS_HANDSHAKE_SERVER_HELLO); +static __always_inline bool is_tls_handshake_server_hello(struct __sk_buff *skb, tls_record_header_t *tls_hdr, __u64 offset, __u32 data_end) { + return is_tls_handshake_type(skb, tls_hdr, offset, data_end, TLS_HANDSHAKE_SERVER_HELLO); } #endif // __TLS_H diff --git a/pkg/network/tracer/testutil/tcp.go b/pkg/network/tracer/testutil/tcp.go index 0c4b749368aa1a..9ef69afc27172c 100644 --- a/pkg/network/tracer/testutil/tcp.go +++ b/pkg/network/tracer/testutil/tcp.go @@ -75,16 +75,3 @@ func (t *TCPServer) Shutdown() { t.ln = nil } } - -// GetFreePort returns a free port on localhost -func GetFreePort() (port uint16, err error) { - var a *net.TCPAddr - if a, err = net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { - var l *net.TCPListener - if l, err = net.ListenTCP("tcp", a); err == nil { - defer l.Close() - return uint16(l.Addr().(*net.TCPAddr).Port), nil - } - } - return -} diff --git a/pkg/network/tracer/tracer_linux_test.go b/pkg/network/tracer/tracer_linux_test.go index ef6c9209022f7f..f1923d47e15734 100644 --- a/pkg/network/tracer/tracer_linux_test.go +++ b/pkg/network/tracer/tracer_linux_test.go @@ -2599,24 +2599,22 @@ func (s *TracerSuite) TestTLSClassification() { if !kprobe.ClassificationSupported(cfg) { t.Skip("protocol classification not supported") } - port, err := tracertestutil.GetFreePort() - require.NoError(t, err) - portAsString := strconv.Itoa(int(port)) tr := setupTracer(t, cfg) type tlsTest struct { name string - postTracerSetup func(t *testing.T) - validation func(t *testing.T, tr *Tracer) + postTracerSetup func(t *testing.T) (port uint16, scenario uint16) + validation func(t *testing.T, tr *Tracer, port uint16, scenario uint16) } + tests := make([]tlsTest, 0) for _, scenario := range []uint16{tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13} { scenario := scenario tests = append(tests, tlsTest{ name: strings.Replace(tls.VersionName(scenario), " ", "-", 1), - postTracerSetup: func(t *testing.T) { - srv := usmtestutil.NewTLSServerWithSpecificVersion("localhost:"+portAsString, func(conn net.Conn) { + postTracerSetup: func(t *testing.T) (uint16, uint16) { + srv := usmtestutil.NewTLSServerWithSpecificVersion("localhost:0", func(conn net.Conn) { defer conn.Close() _, err := io.Copy(conn, conn) if err != nil { @@ -2627,6 +2625,15 @@ func (s *TracerSuite) TestTLSClassification() { done := make(chan struct{}) require.NoError(t, srv.Run(done)) t.Cleanup(func() { close(done) }) + + // Retrieve the actual port assigned to the server + addr := srv.Address() + _, portStr, err := net.SplitHostPort(addr) + require.NoError(t, err) + portInt, err := strconv.Atoi(portStr) + require.NoError(t, err) + port := uint16(portInt) + tlsConfig := &tls.Config{ MinVersion: scenario, MaxVersion: scenario, @@ -2634,27 +2641,27 @@ func (s *TracerSuite) TestTLSClassification() { SessionTicketsDisabled: true, ClientSessionCache: nil, } - conn, err := net.Dial("tcp", "localhost:"+portAsString) + conn, err := net.Dial("tcp", addr) require.NoError(t, err) defer conn.Close() - // Wrap the TCP connection with TLS tlsConn := tls.Client(conn, tlsConfig) - require.NoError(t, tlsConn.Handshake()) + + return port, scenario }, - validation: func(t *testing.T, tr *Tracer) { + validation: func(t *testing.T, tr *Tracer, port uint16, scenario uint16) { require.Eventuallyf(t, func() bool { return validateTLSTags(t, tr, port, scenario) - }, 3*time.Second, 100*time.Millisecond, "couldn't find TLS connection matching: dst port %v", portAsString) + }, 3*time.Second, 100*time.Millisecond, "couldn't find TLS connection matching: dst port %v", port) }, }) } tests = append(tests, tlsTest{ name: "Invalid-TLS-Handshake", - postTracerSetup: func(t *testing.T) { + postTracerSetup: func(t *testing.T) (uint16, uint16) { // server that accepts connections but does not perform TLS handshake - listener, err := net.Listen("tcp", "localhost:"+portAsString) + listener, err := net.Listen("tcp", "localhost:0") require.NoError(t, err) t.Cleanup(func() { listener.Close() }) @@ -2668,21 +2675,33 @@ func (s *TracerSuite) TestTLSClassification() { defer c.Close() buf := make([]byte, 1024) _, _ = c.Read(buf) - // Do nothing + // Do nothing with the data }(conn) } }() + // Retrieve the actual port from the listener address + addr := listener.Addr().String() + _, portStr, err := net.SplitHostPort(addr) + require.NoError(t, err) + portInt, err := strconv.Atoi(portStr) + require.NoError(t, err) + port := uint16(portInt) + // Client connects to the server - conn, err := net.Dial("tcp", "localhost:"+portAsString) + conn, err := net.Dial("tcp", addr) require.NoError(t, err) defer conn.Close() // Send invalid TLS handshake data _, err = conn.Write([]byte("invalid TLS data")) require.NoError(t, err) + + // Since this is invalid TLS, scenario can be set to something irrelevant, e.g., TLS.VersionTLS12 + // or just 0 since the validation doesn't rely on the scenario for this test. + return port, tls.VersionTLS12 }, - validation: func(t *testing.T, tr *Tracer) { + validation: func(t *testing.T, tr *Tracer, port uint16, scenario uint16) { // Verify that no TLS tags are set for this connection require.Eventually(t, func() bool { payload := getConnections(t, tr) @@ -2708,9 +2727,9 @@ func (s *TracerSuite) TestTLSClassification() { tr.RemoveClient(clientID) require.NoError(t, tr.RegisterClient(clientID)) require.NoError(t, tr.Resume(), "enable probes - before post tracer") - tt.postTracerSetup(t) + port, scenario := tt.postTracerSetup(t) require.NoError(t, tr.Pause(), "disable probes - after post tracer") - tt.validation(t, tr) + tt.validation(t, tr, port, scenario) }) } } diff --git a/pkg/network/usm/tests/tracer_usm_linux_test.go b/pkg/network/usm/tests/tracer_usm_linux_test.go index b9a96543e06a35..cb7cf47e30e4a4 100644 --- a/pkg/network/usm/tests/tracer_usm_linux_test.go +++ b/pkg/network/usm/tests/tracer_usm_linux_test.go @@ -286,6 +286,18 @@ func testProtocolConnectionProtocolMapCleanup(t *testing.T, tr *tracer.Tracer, c }) } +func getFreePort() (port uint16, err error) { + var a *net.TCPAddr + if a, err = net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { + var l *net.TCPListener + if l, err = net.ListenTCP("tcp", a); err == nil { + defer l.Close() + return uint16(l.Addr().(*net.TCPAddr).Port), nil + } + } + return +} + func (s *USMSuite) TestIgnoreTLSClassificationIfApplicationProtocolWasDetected() { t := s.T() cfg := tracertestutil.Config() @@ -384,7 +396,7 @@ func (s *USMSuite) TestIgnoreTLSClassificationIfApplicationProtocolWasDetected() } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - clientPort, err := tracertestutil.GetFreePort() + clientPort, err := getFreePort() require.NoError(t, err) dialer := &net.Dialer{ LocalAddr: &net.TCPAddr{ @@ -444,7 +456,7 @@ func (s *USMSuite) TestTLSClassification() { t.Skip("TLS classification platform not supported") } - port, err := tracertestutil.GetFreePort() + port, err := getFreePort() require.NoError(t, err) portAsString := strconv.Itoa(int(port)) From 8cb4e43c3257f8aa704358b3f469bed6347fa651 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Thu, 12 Dec 2024 15:33:43 -0500 Subject: [PATCH 46/53] store just content type byte instead of tls record header --- .../classification/protocol-classification.h | 10 ++++++---- .../ebpf/c/protocols/classification/usm-context.h | 2 +- pkg/network/ebpf/c/protocols/tls/tls.h | 15 ++++++--------- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h index 960f01a5375872..9122c8bf58ff07 100644 --- a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h +++ b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h @@ -170,7 +170,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct if (!protocol_stack) { return; } - + // TLS classification update_protocol_information(usm_ctx, protocol_stack, PROTOCOL_TLS); if (tls_hdr.content_type == TLS_APPLICATION_DATA) { // We can't classify TLS encrypted traffic further, so return early @@ -180,7 +180,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct // Parse TLS payload tls_info_t *tags = get_or_create_tls_enhanced_tags(&usm_ctx->tuple); if (tags) { - usm_ctx->tls_header = tls_hdr; + usm_ctx->tls_content_type = tls_hdr.content_type; // The packet is a TLS handshake, so trigger some tail calls // to extract metadata from the payload goto next_program; @@ -227,7 +227,8 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint_tls_ha } __u64 offset = usm_ctx->skb_info.data_off + sizeof(tls_record_header_t); __u32 data_end = usm_ctx->skb_info.data_end; - if (!is_tls_handshake_client_hello(skb, &usm_ctx->tls_header, offset, data_end)) { + __u8 content_type = usm_ctx->tls_content_type; + if (!is_tls_handshake_client_hello(skb, content_type, offset, data_end)) { goto next_program; } if (!parse_client_hello(skb, offset, data_end, tls_info)) { @@ -249,7 +250,8 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint_tls_ha } __u64 offset = usm_ctx->skb_info.data_off + sizeof(tls_record_header_t); __u32 data_end = usm_ctx->skb_info.data_end; - if (!is_tls_handshake_server_hello(skb, &usm_ctx->tls_header, offset, data_end)) { + __u8 content_type = usm_ctx->tls_content_type; + if (!is_tls_handshake_server_hello(skb, content_type, offset, data_end)) { goto next_program; } if (!parse_server_hello(skb, offset, data_end, tls_info)) { diff --git a/pkg/network/ebpf/c/protocols/classification/usm-context.h b/pkg/network/ebpf/c/protocols/classification/usm-context.h index 29d2c420060cd5..ce036112fe4118 100644 --- a/pkg/network/ebpf/c/protocols/classification/usm-context.h +++ b/pkg/network/ebpf/c/protocols/classification/usm-context.h @@ -23,7 +23,7 @@ typedef struct { // bit mask with layers that should be skiped u16 routing_skip_layers; classification_prog_t routing_current_program; - tls_record_header_t tls_header; + __u8 tls_content_type; } usm_context_t; // Kernels before 4.7 do not know about per-cpu array maps. diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index d52479fa6951b3..106b7195cd564a 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -429,11 +429,8 @@ static __always_inline bool parse_server_hello(struct __sk_buff *skb, __u64 offs } // is_tls_handshake_type checks if the handshake type is the expected type (client or server hello) -static __always_inline bool is_tls_handshake_type(struct __sk_buff *skb, tls_record_header_t *tls_hdr, __u64 offset, __u32 data_end, __u8 expected_handshake_type) { - if (!tls_hdr) { - return false; - } - if (tls_hdr->content_type != TLS_HANDSHAKE) { +static __always_inline bool is_tls_handshake_type(struct __sk_buff *skb, __u8 content_type, __u64 offset, __u32 data_end, __u8 expected_handshake_type) { + if (content_type != TLS_HANDSHAKE) { return false; } @@ -450,13 +447,13 @@ static __always_inline bool is_tls_handshake_type(struct __sk_buff *skb, tls_rec } // is_tls_handshake_client_hello checks if the packet is a TLS ClientHello message -static __always_inline bool is_tls_handshake_client_hello(struct __sk_buff *skb, tls_record_header_t *tls_hdr, __u64 offset, __u32 data_end) { - return is_tls_handshake_type(skb, tls_hdr, offset, data_end, TLS_HANDSHAKE_CLIENT_HELLO); +static __always_inline bool is_tls_handshake_client_hello(struct __sk_buff *skb, __u8 content_type, __u64 offset, __u32 data_end) { + return is_tls_handshake_type(skb, content_type, offset, data_end, TLS_HANDSHAKE_CLIENT_HELLO); } // is_tls_handshake_server_hello checks if the packet is a TLS ServerHello message -static __always_inline bool is_tls_handshake_server_hello(struct __sk_buff *skb, tls_record_header_t *tls_hdr, __u64 offset, __u32 data_end) { - return is_tls_handshake_type(skb, tls_hdr, offset, data_end, TLS_HANDSHAKE_SERVER_HELLO); +static __always_inline bool is_tls_handshake_server_hello(struct __sk_buff *skb, __u8 content_type, __u64 offset, __u32 data_end) { + return is_tls_handshake_type(skb, content_type, offset, data_end, TLS_HANDSHAKE_SERVER_HELLO); } #endif // __TLS_H From d97e272ab11ee2cb3fe1250f707d174172702fc7 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Thu, 12 Dec 2024 19:31:47 -0500 Subject: [PATCH 47/53] appease linter --- pkg/network/tracer/tracer_linux_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/network/tracer/tracer_linux_test.go b/pkg/network/tracer/tracer_linux_test.go index 104e61bdc6a51f..d9187fa5e166a9 100644 --- a/pkg/network/tracer/tracer_linux_test.go +++ b/pkg/network/tracer/tracer_linux_test.go @@ -2685,7 +2685,7 @@ func (s *TracerSuite) TestTLSClassification() { // or just 0 since the validation doesn't rely on the scenario for this test. return port, tls.VersionTLS12 }, - validation: func(t *testing.T, tr *Tracer, port uint16, scenario uint16) { + validation: func(t *testing.T, tr *Tracer, port uint16, _ uint16) { // Verify that no TLS tags are set for this connection require.Eventually(t, func() bool { payload := getConnections(t, tr) From b063757202f1bba27e892fd99bd0b9a08c100dd5 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Mon, 16 Dec 2024 14:50:04 -0500 Subject: [PATCH 48/53] fix typo and unexport vars --- pkg/network/ebpf/c/tracer/tracer.h | 2 +- pkg/network/protocols/tls/types.go | 26 +++++++++++++------------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/pkg/network/ebpf/c/tracer/tracer.h b/pkg/network/ebpf/c/tracer/tracer.h index 2eb619c7337807..c23c98b1891628 100644 --- a/pkg/network/ebpf/c/tracer/tracer.h +++ b/pkg/network/ebpf/c/tracer/tracer.h @@ -34,7 +34,7 @@ typedef struct { __u16 cipher_suite; __u8 offered_versions; } tls_info_t; -\ + typedef struct { tls_info_t info; __u64 updated; diff --git a/pkg/network/protocols/tls/types.go b/pkg/network/protocols/tls/types.go index 3d1974f434dc97..2bddcd5776bdeb 100644 --- a/pkg/network/protocols/tls/types.go +++ b/pkg/network/protocols/tls/types.go @@ -16,10 +16,10 @@ const ( TagTLSVersion = "tls.version:" TagTLSCipherSuiteID = "tls.cipher_suite_id:" TagTLSClientVersion = "tls.client_version:" - Version10 = "tls_1.0" - Version11 = "tls_1.1" - Version12 = "tls_1.2" - Version13 = "tls_1.3" + version10 = "tls_1.0" + version11 = "tls_1.1" + version12 = "tls_1.2" + version13 = "tls_1.3" ) // Bitmask constants for Offered_versions matching kernelspace definitions @@ -32,18 +32,18 @@ const ( // VersionTags maps TLS versions to tag names for server chosen version (exported for testing) var VersionTags = map[uint16]string{ - tls.VersionTLS10: TagTLSVersion + Version10, - tls.VersionTLS11: TagTLSVersion + Version11, - tls.VersionTLS12: TagTLSVersion + Version12, - tls.VersionTLS13: TagTLSVersion + Version13, + tls.VersionTLS10: TagTLSVersion + version10, + tls.VersionTLS11: TagTLSVersion + version11, + tls.VersionTLS12: TagTLSVersion + version12, + tls.VersionTLS13: TagTLSVersion + version13, } // ClientVersionTags maps TLS versions to tag names for client offered versions (exported for testing) var ClientVersionTags = map[uint16]string{ - tls.VersionTLS10: TagTLSClientVersion + Version10, - tls.VersionTLS11: TagTLSClientVersion + Version11, - tls.VersionTLS12: TagTLSClientVersion + Version12, - tls.VersionTLS13: TagTLSClientVersion + Version13, + tls.VersionTLS10: TagTLSClientVersion + version10, + tls.VersionTLS11: TagTLSClientVersion + version11, + tls.VersionTLS12: TagTLSClientVersion + version12, + tls.VersionTLS13: TagTLSClientVersion + version13, } // Mapping of offered version bitmasks to version constants @@ -119,10 +119,10 @@ func hexCipherSuiteTag(cipherSuite uint16) string { // GetTLSDynamicTags generates dynamic tags based on TLS information func GetTLSDynamicTags(tls *Tags) map[string]struct{} { - tags := make(map[string]struct{}) if tls == nil { return nil } + tags := make(map[string]struct{}) // Server chosen version if tag, ok := VersionTags[tls.ChosenVersion]; ok { From 7b42eb085da159e038e17f7730aa0cf572a93023 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Thu, 19 Dec 2024 15:01:32 -0500 Subject: [PATCH 49/53] subset of review comments addressed --- .../classification/protocol-classification.h | 21 +- .../classification/routing-helpers.h | 16 +- .../ebpf/c/protocols/classification/routing.h | 2 +- .../classification/shared-tracer-maps.h | 2 +- .../c/protocols/classification/usm-context.h | 1 + pkg/network/ebpf/c/protocols/tls/tls.h | 371 ++++++++++-------- pkg/network/encoding/marshal/format.go | 3 +- pkg/network/protocols/tls/types.go | 32 +- pkg/network/protocols/tls/types_test.go | 4 +- pkg/network/tracer/tracer_linux_test.go | 9 +- 10 files changed, 255 insertions(+), 206 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h index 9122c8bf58ff07..145386c65d73df 100644 --- a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h +++ b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h @@ -63,10 +63,7 @@ // updates the the protocol stack and adds the current layer to the routing skip list static __always_inline void update_protocol_information(usm_context_t *usm_ctx, protocol_stack_t *stack, protocol_t proto) { set_protocol(stack, proto); - // Mark the current layer as known except for TLS, since there is still metadata to be extracted - if (proto != PROTOCOL_TLS) { - usm_ctx->routing_skip_layers |= proto; - } + usm_ctx->routing_skip_layers |= proto; } // Check if the connections is used for gRPC traffic. @@ -152,7 +149,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct protocol_stack_t *protocol_stack = get_protocol_stack_if_exists(&usm_ctx->tuple); - if (is_fully_classified(protocol_stack)) { + if (is_fully_classified(protocol_stack) || is_protocol_layer_known(protocol_stack, LAYER_ENCRYPTION)) { return; } @@ -171,13 +168,13 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct return; } // TLS classification - update_protocol_information(usm_ctx, protocol_stack, PROTOCOL_TLS); - if (tls_hdr.content_type == TLS_APPLICATION_DATA) { + if (tls_hdr.content_type != TLS_HANDSHAKE) { // We can't classify TLS encrypted traffic further, so return early + update_protocol_information(usm_ctx, protocol_stack, PROTOCOL_TLS); return; } - // Parse TLS payload + // Parse TLS handshake payload tls_info_t *tags = get_or_create_tls_enhanced_tags(&usm_ctx->tuple); if (tags) { usm_ctx->tls_content_type = tls_hdr.content_type; @@ -227,8 +224,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint_tls_ha } __u64 offset = usm_ctx->skb_info.data_off + sizeof(tls_record_header_t); __u32 data_end = usm_ctx->skb_info.data_end; - __u8 content_type = usm_ctx->tls_content_type; - if (!is_tls_handshake_client_hello(skb, content_type, offset, data_end)) { + if (!is_tls_handshake_client_hello(skb, usm_ctx->tls_content_type, offset, usm_ctx->skb_info.data_end)) { goto next_program; } if (!parse_client_hello(skb, offset, data_end, tls_info)) { @@ -250,8 +246,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint_tls_ha } __u64 offset = usm_ctx->skb_info.data_off + sizeof(tls_record_header_t); __u32 data_end = usm_ctx->skb_info.data_end; - __u8 content_type = usm_ctx->tls_content_type; - if (!is_tls_handshake_server_hello(skb, content_type, offset, data_end)) { + if (!is_tls_handshake_server_hello(skb, usm_ctx->tls_content_type, offset, data_end)) { goto next_program; } if (!parse_server_hello(skb, offset, data_end, tls_info)) { @@ -262,6 +257,8 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint_tls_ha if (!protocol_stack) { return; } + update_protocol_information(usm_ctx, protocol_stack, PROTOCOL_TLS); + // We can't classify TLS encrypted traffic further, so return early return; next_program: diff --git a/pkg/network/ebpf/c/protocols/classification/routing-helpers.h b/pkg/network/ebpf/c/protocols/classification/routing-helpers.h index 83b27ba19644c2..79dc2d79bb2745 100644 --- a/pkg/network/ebpf/c/protocols/classification/routing-helpers.h +++ b/pkg/network/ebpf/c/protocols/classification/routing-helpers.h @@ -33,7 +33,21 @@ static __always_inline u16 get_current_program_layer(classification_prog_t curre return 0; } -// next_layer_entrypoint returns the entrypoint of the next layer that should be executed +// debug for if we don't reorder the programs +// static __always_inline u16 get_current_program_layer(classification_prog_t current_program) { +// if (current_program > __PROG_APPLICATION && current_program < __PROG_API) { +// return LAYER_APPLICATION_BIT; +// } +// if (current_program > __PROG_API && current_program < __PROG_ENCRYPTION) { +// return LAYER_ENCRYPTION_BIT; +// } +// if (current_program > __PROG_ENCRYPTION && current_program < CLASSIFICATION_PROG_MAX) { +// return LAYER_API_BIT; +// } + +// return 0; +// } + static __always_inline classification_prog_t next_layer_entrypoint(usm_context_t *usm_ctx) { u16 to_skip = usm_ctx->routing_skip_layers; diff --git a/pkg/network/ebpf/c/protocols/classification/routing.h b/pkg/network/ebpf/c/protocols/classification/routing.h index 131fa15a918195..f801003f202bdc 100644 --- a/pkg/network/ebpf/c/protocols/classification/routing.h +++ b/pkg/network/ebpf/c/protocols/classification/routing.h @@ -70,7 +70,7 @@ static __always_inline void init_routing_cache(usm_context_t *usm_ctx, protocol_ if (stack->layer_api || !has_available_program(__PROG_API)) { usm_ctx->routing_skip_layers |= LAYER_API_BIT; } - if (stack->flags == FLAG_FULLY_CLASSIFIED || !has_available_program(__PROG_ENCRYPTION)) { + if (stack->layer_encryption || !has_available_program(__PROG_ENCRYPTION)) { usm_ctx->routing_skip_layers |= LAYER_ENCRYPTION_BIT; } } diff --git a/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h b/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h index d8e1c898f47b73..5d67ffa013f777 100644 --- a/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h +++ b/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h @@ -168,7 +168,7 @@ static __always_inline tls_info_t* get_or_create_tls_enhanced_tags(conn_tuple_t tls_info_wrapper_t empty_tags_wrapper = {}; empty_tags_wrapper.updated = bpf_ktime_get_ns(); - bpf_map_update_elem(&tls_enhanced_tags, &normalized_tup, &empty_tags_wrapper, BPF_ANY); + bpf_map_update_with_telemetry(tls_enhanced_tags, &normalized_tup, &empty_tags_wrapper, BPF_ANY); tls_info_wrapper_t *wrapper_ptr = bpf_map_lookup_elem(&tls_enhanced_tags, &normalized_tup); if (!wrapper_ptr) { return NULL; diff --git a/pkg/network/ebpf/c/protocols/classification/usm-context.h b/pkg/network/ebpf/c/protocols/classification/usm-context.h index ce036112fe4118..87027cf01d700a 100644 --- a/pkg/network/ebpf/c/protocols/classification/usm-context.h +++ b/pkg/network/ebpf/c/protocols/classification/usm-context.h @@ -15,6 +15,7 @@ typedef struct { size_t size; } classification_buffer_t; +// TODO: rename this struct to `classification_context_t` typedef struct { struct __sk_buff *owner; conn_tuple_t tuple; diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index 106b7195cd564a..b9e17ed208153a 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -11,36 +11,48 @@ #define TLS_VERSION12 0x0303 #define TLS_VERSION13 0x0304 -// TLS Content Types (RFC 5246 Section 6.2.1) +// TLS Content Types (https://www.rfc-editor.org/rfc/rfc5246#page-19 6.2. Record Layer) #define TLS_HANDSHAKE 0x16 #define TLS_APPLICATION_DATA 0x17 +#define TLS_CHANGE_CIPHER_SPEC 0x14 +#define TLS_ALERT 0x15 // TLS Handshake Types #define TLS_HANDSHAKE_CLIENT_HELLO 0x01 #define TLS_HANDSHAKE_SERVER_HELLO 0x02 // Bitmask constants for offered versions -#define TLS_VERSION10_BIT 0x01 -#define TLS_VERSION11_BIT 0x02 -#define TLS_VERSION12_BIT 0x04 -#define TLS_VERSION13_BIT 0x08 +#define TLS_VERSION10_BIT 1 << 0 +#define TLS_VERSION11_BIT 1 << 1 +#define TLS_VERSION12_BIT 1 << 2 +#define TLS_VERSION13_BIT 1 << 3 // Maximum number of extensions to parse when looking for SUPPORTED_VERSIONS_EXTENSION #define MAX_EXTENSIONS 16 +// The supported_versions extension for TLS 1.3 is described in RFC 8446 Section 4.2.1 #define SUPPORTED_VERSIONS_EXTENSION 0x002B // Maximum TLS record payload size (16 KB) #define TLS_MAX_PAYLOAD_LENGTH (1 << 14) -// Field Lengths -#define TLS_HANDSHAKE_LENGTH 3 // Handshake length is 3 bytes -#define RANDOM_LENGTH 32 // Random field length in Client/Server Hello -#define PROTOCOL_VERSION_LENGTH 2 // Protocol version field is 2 bytes -#define SESSION_ID_LENGTH 1 // Session ID length field is 1 byte -#define CIPHER_SUITES_LENGTH 2 // Cipher Suites length field is 2 bytes -#define COMPRESSION_METHODS_LENGTH 1 // Compression Methods length field is 1 byte -#define EXTENSION_TYPE_LENGTH 2 // Extension Type field is 2 bytes -#define EXTENSION_LENGTH_FIELD 2 // Extension Length field is 2 bytes +// The following field lengths and message formats are defined by the TLS specifications +// For TLS 1.2 (and earlier) see: +// RFC 5246 - The Transport Layer Security (TLS) Protocol Version 1.2 +// https://tools.ietf.org/html/rfc5246 +// Particularly Section 7.4 details handshake messages and their fields, and Section 6.2.1 +// covers the TLS record layer. +// For TLS 1.3, see: +// RFC 8446 - The Transport Layer Security (TLS) Protocol Version 1.3 +// https://tools.ietf.org/html/rfc8446 +// Many handshake structures are similar, but some extensions (like supported_versions) are defined here +#define TLS_HANDSHAKE_LENGTH 3 // Handshake length is 3 bytes (RFC 5246 Section 7.4) +#define RANDOM_LENGTH 32 // Random field length in Client/Server Hello (RFC 5246 Section 7.4.1.2) +#define PROTOCOL_VERSION_LENGTH 2 // Protocol version field is 2 bytes (RFC 5246 Section 6.2.1) +#define SESSION_ID_LENGTH 1 // Session ID length field is 1 byte (RFC 5246 Section 7.4.1.2) +#define CIPHER_SUITES_LENGTH 2 // Cipher Suites length field is 2 bytes (RFC 5246 Section 7.4.1.2) +#define COMPRESSION_METHODS_LENGTH 1 // Compression Methods length field is 1 byte (RFC 5246 Section 7.4.1.2) +#define EXTENSION_TYPE_LENGTH 2 // Extension Type field is 2 bytes (RFC 5246 Section 7.4.1.4) +#define EXTENSION_LENGTH_FIELD 2 // Extension Length field is 2 bytes (RFC 5246 Section 7.4.1.4) // For single-byte fields (list lengths, etc.) #define SINGLE_BYTE_LENGTH 1 @@ -48,7 +60,7 @@ // Minimum extension header length = Extension Type (2 bytes) + Extension Length (2 bytes) = 4 bytes #define MIN_EXTENSION_HEADER_LENGTH (EXTENSION_TYPE_LENGTH + EXTENSION_LENGTH_FIELD) -// Maximum number of supported versions we unroll for +// Maximum number of supported versions we unroll for (all TLS versions) #define MAX_SUPPORTED_VERSIONS 4 // TLS record layer header structure (RFC 5246) @@ -122,18 +134,21 @@ static __always_inline bool read_tls_record_header(struct __sk_buff *skb, __u64 } // is_tls checks if the packet is a TLS packet and reads the TLS record header +// Uses RFC 5246 Section 6.2.1 (https://www.rfc-editor.org/rfc/rfc5246#page-19) for record structure and content types static __always_inline bool is_tls(struct __sk_buff *skb, __u64 header_offset, __u32 data_end, tls_record_header_t *tls_hdr) { - // Read and validate the TLS record header if (!read_tls_record_header(skb, header_offset, data_end, tls_hdr)) { return false; } // Validate content type - return tls_hdr->content_type == TLS_HANDSHAKE || tls_hdr->content_type == TLS_APPLICATION_DATA; + __u8 ct = tls_hdr->content_type; + return ct == TLS_HANDSHAKE || ct == TLS_APPLICATION_DATA || ct == TLS_CHANGE_CIPHER_SPEC || ct == TLS_ALERT; } -// Parses the TLS handshake header to extract handshake_length and protocol_version -static __always_inline bool parse_tls_handshake_header(struct __sk_buff *skb, __u64 *offset, __u32 data_end, __u32 *handshake_length, __u16 *protocol_version) { +// parse_tls_handshake_header extracts handshake_length and protocol_version from the handshake message. +// The handshake header (RFC 5246 Section 7.4, https://tools.ietf.org/html/rfc5246) starts with: +// handshake_type (1 byte), length (3 bytes), then protocol_version in case of Client/Server Hello. +static __always_inline bool parse_tls_handshake_header(struct __sk_buff *skb, __u32 *offset, __u32 data_end, __u32 *handshake_length, __u16 *protocol_version) { *offset += SINGLE_BYTE_LENGTH; // Move past handshake type (1 byte) // Read handshake length (3 bytes) @@ -168,156 +183,57 @@ static __always_inline bool parse_tls_handshake_header(struct __sk_buff *skb, __ return true; } -// skip_random_and_session_id Skips the Random (32 bytes) and Session ID from the TLS hello messages by incrementing the offset pointer -static __always_inline bool skip_random_and_session_id(struct __sk_buff *skb, __u64 *offset, __u32 data_end) { - // Skip Random (32 bytes) - *offset += RANDOM_LENGTH; +// parse_client_hello parses the ClientHello message and populates tags +// Reference: RFC 5246 Section 7.4.1.2 (Client Hello), https://tools.ietf.org/html/rfc5246 +// Structure (simplified): +// handshake_type (1 byte), length (3 bytes), version (2 bytes), random(32 bytes), session_id_length(1 byte), session_id(variable), cipher_suites_length(2 bytes), cipher_suites(variable), compression_methods_length(1 byte), compression_methods(variable), extensions_length(2 bytes), extensions(variable) +static __always_inline bool parse_client_hello(struct __sk_buff *skb, __u32 offset, __u32 data_end, tls_info_t *tags) { + __u32 handshake_length; + __u16 client_version; - // Read Session ID Length (1 byte) - if (*offset + SESSION_ID_LENGTH > data_end) { - return false; - } - __u8 session_id_length; - if (bpf_skb_load_bytes(skb, *offset, &session_id_length, SESSION_ID_LENGTH) < 0) { + if (!parse_tls_handshake_header(skb, &offset, data_end, &handshake_length, &client_version)) { return false; } - *offset += SESSION_ID_LENGTH; - - // Skip Session ID - *offset += session_id_length; - - // Ensure we don't read beyond the packet - return *offset <= data_end; -} - -// parse_supported_versions_extension looks for the supported_versions extension in the ClientHello or ServerHello and populates tags -static __always_inline bool parse_supported_versions_extension(struct __sk_buff *skb, __u64 *offset, __u32 data_end, __u64 extensions_end, tls_info_t *tags, bool is_client_hello) { - if (is_client_hello) { - // Read supported version list length (1 byte) - if (*offset + SINGLE_BYTE_LENGTH > data_end || *offset + SINGLE_BYTE_LENGTH > extensions_end) { - return false; - } - __u8 sv_list_length; - if (bpf_skb_load_bytes(skb, *offset, &sv_list_length, SINGLE_BYTE_LENGTH) < 0) { - return false; - } - *offset += SINGLE_BYTE_LENGTH; - - if (*offset + sv_list_length > data_end || *offset + sv_list_length > extensions_end) { - return false; - } - - // Parse the list of supported versions (2 bytes each) - __u8 sv_offset = 0; - __u16 sv_version; - #pragma unroll(MAX_SUPPORTED_VERSIONS) - for (int idx = 0; idx < MAX_SUPPORTED_VERSIONS; idx++) { - if (sv_offset + 1 >= sv_list_length) { - break; - } - // Each supported version is 2 bytes - if (*offset + PROTOCOL_VERSION_LENGTH > data_end) { - return false; - } - - if (bpf_skb_load_bytes(skb, *offset, &sv_version, PROTOCOL_VERSION_LENGTH) < 0) { - return false; - } - sv_version = bpf_ntohs(sv_version); - *offset += PROTOCOL_VERSION_LENGTH; - - set_tls_offered_version(tags, sv_version); - sv_offset += PROTOCOL_VERSION_LENGTH; - } - } else { - // ServerHello - // The selected_version field is 2 bytes - if (*offset + PROTOCOL_VERSION_LENGTH > data_end) { - return false; - } - - // Read selected version (2 bytes) - __u16 selected_version; - if (bpf_skb_load_bytes(skb, *offset, &selected_version, PROTOCOL_VERSION_LENGTH) < 0) { - return false; - } - selected_version = bpf_ntohs(selected_version); - *offset += PROTOCOL_VERSION_LENGTH; - - tags->chosen_version = selected_version; - } - - return true; -} - -// parse_tls_extensions parses TLS extensions and looks for the supported_versions extension -static __always_inline bool parse_tls_extensions(struct __sk_buff *skb, __u64 *offset, __u32 data_end, __u64 extensions_end, tls_info_t *tags, bool is_client_hello) { - __u16 extension_type; - __u16 extension_length; - #pragma unroll(MAX_EXTENSIONS) - for (int i = 0; i < MAX_EXTENSIONS; i++) { - if (*offset + MIN_EXTENSION_HEADER_LENGTH > extensions_end) { - break; - } + set_tls_offered_version(tags, client_version); - // Read Extension Type (2 bytes) - if (bpf_skb_load_bytes(skb, *offset, &extension_type, EXTENSION_TYPE_LENGTH) < 0) { - return false; - } - extension_type = bpf_ntohs(extension_type); - *offset += EXTENSION_TYPE_LENGTH; + // If client_version < TLS 1.2, no extensions to parse + if (client_version != TLS_VERSION12) { + // Skip Random (32 bytes) + offset += RANDOM_LENGTH; - // Read Extension Length (2 bytes) - if (bpf_skb_load_bytes(skb, *offset, &extension_length, EXTENSION_LENGTH_FIELD) < 0) { + // Session ID Length (1 byte) + if (offset + SESSION_ID_LENGTH > data_end) { return false; } - extension_length = bpf_ntohs(extension_length); - *offset += EXTENSION_LENGTH_FIELD; - - if (*offset + extension_length > data_end || *offset + extension_length > extensions_end) { + __u8 session_id_length; + if (bpf_skb_load_bytes(skb, offset, &session_id_length, SESSION_ID_LENGTH) < 0) { return false; } + offset += SESSION_ID_LENGTH; - if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { - if (!parse_supported_versions_extension(skb, offset, data_end, extensions_end, tags, is_client_hello)) { - return false; - } - } else { - // Skip other extensions - *offset += extension_length; - } - - if (*offset >= extensions_end) { - break; - } + // Skip Session ID + offset += session_id_length; + return true; } - return true; -} - -// parse_client_hello parses the ClientHello message and populates tags -static __always_inline bool parse_client_hello(struct __sk_buff *skb, __u64 offset, __u32 data_end, tls_info_t *tags) { - __u32 handshake_length; - __u16 client_version; + // TLS 1.2 case: + // Skip Random (32 bytes) + offset += RANDOM_LENGTH; - if (!parse_tls_handshake_header(skb, &offset, data_end, &handshake_length, &client_version)) { + // Session ID Length (1 byte) + if (offset + SESSION_ID_LENGTH > data_end) { return false; } - - set_tls_offered_version(tags, client_version); - - // TLS 1.2 is the highest version we will see in the header. If the connection is actually a higher version (1.3), - // it must be extracted from the extensions. Lower versions (1.0, 1.1) will not have extensions. - if (client_version != TLS_VERSION12) { - return true; - } - - if (!skip_random_and_session_id(skb, &offset, data_end)) { + __u8 session_id_length; + if (bpf_skb_load_bytes(skb, offset, &session_id_length, SESSION_ID_LENGTH) < 0) { return false; } + offset += SESSION_ID_LENGTH; + + // Skip Session ID + offset += session_id_length; - // Read Cipher Suites Length (2 bytes) if (offset + CIPHER_SUITES_LENGTH > data_end) { return false; } @@ -331,7 +247,7 @@ static __always_inline bool parse_client_hello(struct __sk_buff *skb, __u64 offs // Skip Cipher Suites offset += cipher_suites_length; - // Read Compression Methods Length (1 byte) + // Compression Methods Length (1 byte) if (offset + COMPRESSION_METHODS_LENGTH > data_end) { return false; } @@ -344,12 +260,10 @@ static __always_inline bool parse_client_hello(struct __sk_buff *skb, __u64 offs // Skip Compression Methods offset += compression_methods_length; - // Check if extensions are present + // Extensions Length (2 bytes) if (offset + EXTENSION_LENGTH_FIELD > data_end) { return false; } - - // Read Extensions Length (2 bytes) __u16 extensions_length; if (bpf_skb_load_bytes(skb, offset, &extensions_length, EXTENSION_LENGTH_FIELD) < 0) { return false; @@ -363,11 +277,85 @@ static __always_inline bool parse_client_hello(struct __sk_buff *skb, __u64 offs __u64 extensions_end = offset + extensions_length; - return parse_tls_extensions(skb, &offset, data_end, extensions_end, tags, true); + // Inline extension parsing: + __u16 extension_type; + __u16 extension_length; + #pragma unroll(MAX_EXTENSIONS) + for (int i = 0; i < MAX_EXTENSIONS; i++) { + if (offset + MIN_EXTENSION_HEADER_LENGTH > extensions_end) { + break; + } + + // Read Extension Type (2 bytes) + if (bpf_skb_load_bytes(skb, offset, &extension_type, EXTENSION_TYPE_LENGTH) < 0) { + return false; + } + extension_type = bpf_ntohs(extension_type); + offset += EXTENSION_TYPE_LENGTH; + + // Read Extension Length (2 bytes) + if (bpf_skb_load_bytes(skb, offset, &extension_length, EXTENSION_LENGTH_FIELD) < 0) { + return false; + } + extension_length = bpf_ntohs(extension_length); + offset += EXTENSION_LENGTH_FIELD; + + if (offset + extension_length > data_end || offset + extension_length > extensions_end) { + return false; + } + + if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { + if (offset + SINGLE_BYTE_LENGTH > data_end || offset + SINGLE_BYTE_LENGTH > extensions_end) { + return false; + } + __u8 sv_list_length; + if (bpf_skb_load_bytes(skb, offset, &sv_list_length, SINGLE_BYTE_LENGTH) < 0) { + return false; + } + offset += SINGLE_BYTE_LENGTH; + + if (offset + sv_list_length > data_end || offset + sv_list_length > extensions_end) { + return false; + } + + __u8 sv_offset = 0; + __u16 sv_version; + #pragma unroll(MAX_SUPPORTED_VERSIONS) + for (int j = 0; j < MAX_SUPPORTED_VERSIONS; j++) { + if (sv_offset + 1 >= sv_list_length) { + break; + } + if (offset + PROTOCOL_VERSION_LENGTH > data_end) { + return false; + } + + if (bpf_skb_load_bytes(skb, offset, &sv_version, PROTOCOL_VERSION_LENGTH) < 0) { + return false; + } + sv_version = bpf_ntohs(sv_version); + offset += PROTOCOL_VERSION_LENGTH; + + set_tls_offered_version(tags, sv_version); + sv_offset += PROTOCOL_VERSION_LENGTH; + } + } else { + // Skip other extensions + offset += extension_length; + } + + if (offset >= extensions_end) { + break; + } + } + + return true; } // parse_server_hello parses the ServerHello message and populates tags -static __always_inline bool parse_server_hello(struct __sk_buff *skb, __u64 offset, __u32 data_end, tls_info_t *tags) { +// Reference: RFC 5246 Section 7.4.1.2 (Server Hello), https://tools.ietf.org/html/rfc5246 +// Structure (simplified): +// handshake_type(1), length(3), version(2), random(32), session_id_length(1), session_id(variable), cipher_suite(2), compression_method(1), extensions_length(2), extensions(variable) +static __always_inline bool parse_server_hello(struct __sk_buff *skb, __u32 offset, __u32 data_end, tls_info_t *tags) { __u32 handshake_length; __u16 server_version; @@ -380,9 +368,17 @@ static __always_inline bool parse_server_hello(struct __sk_buff *skb, __u64 offs // The actual version is embedded in the supported_versions extension tags->chosen_version = server_version; - if (!skip_random_and_session_id(skb, &offset, data_end)) { + offset += RANDOM_LENGTH; // Skip Random + + if (offset + SESSION_ID_LENGTH > data_end) { + return false; + } + __u8 session_id_length; + if (bpf_skb_load_bytes(skb, offset, &session_id_length, SESSION_ID_LENGTH) < 0) { return false; } + offset += SESSION_ID_LENGTH; + offset += session_id_length; // Skip Session ID // Read Cipher Suite (2 bytes) if (offset + CIPHER_SUITES_LENGTH > data_end) { @@ -395,8 +391,7 @@ static __always_inline bool parse_server_hello(struct __sk_buff *skb, __u64 offs cipher_suite = bpf_ntohs(cipher_suite); offset += CIPHER_SUITES_LENGTH; - // Skip Compression Method (1 byte) - offset += COMPRESSION_METHODS_LENGTH; + offset += COMPRESSION_METHODS_LENGTH; // Skip Compression Method tags->cipher_suite = cipher_suite; @@ -425,11 +420,61 @@ static __always_inline bool parse_server_hello(struct __sk_buff *skb, __u64 offs __u64 extensions_end = offset + extensions_length; - return parse_tls_extensions(skb, &offset, data_end, extensions_end, tags, false); + __u16 extension_type; + __u16 extension_length; + #pragma unroll(MAX_EXTENSIONS) + for (int i = 0; i < MAX_EXTENSIONS; i++) { + if (offset + MIN_EXTENSION_HEADER_LENGTH > extensions_end) { + break; + } + + // Read Extension Type (2 bytes) + if (bpf_skb_load_bytes(skb, offset, &extension_type, EXTENSION_TYPE_LENGTH) < 0) { + return false; + } + extension_type = bpf_ntohs(extension_type); + offset += EXTENSION_TYPE_LENGTH; + + // Read Extension Length (2 bytes) + if (bpf_skb_load_bytes(skb, offset, &extension_length, EXTENSION_LENGTH_FIELD) < 0) { + return false; + } + extension_length = bpf_ntohs(extension_length); + offset += EXTENSION_LENGTH_FIELD; + + if (offset + extension_length > data_end || offset + extension_length > extensions_end) { + return false; + } + + if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { + // Inline parse_supported_versions_extension for ServerHello + if (offset + PROTOCOL_VERSION_LENGTH > data_end) { + return false; + } + + __u16 selected_version; + if (bpf_skb_load_bytes(skb, offset, &selected_version, PROTOCOL_VERSION_LENGTH) < 0) { + return false; + } + selected_version = bpf_ntohs(selected_version); + offset += PROTOCOL_VERSION_LENGTH; + + tags->chosen_version = selected_version; + } else { + // Skip other extensions + offset += extension_length; + } + + if (offset >= extensions_end) { + break; + } + } + + return true; } // is_tls_handshake_type checks if the handshake type is the expected type (client or server hello) -static __always_inline bool is_tls_handshake_type(struct __sk_buff *skb, __u8 content_type, __u64 offset, __u32 data_end, __u8 expected_handshake_type) { +static __always_inline bool is_tls_handshake_type(struct __sk_buff *skb, __u8 content_type, __u32 offset, __u32 data_end, __u8 expected_handshake_type) { if (content_type != TLS_HANDSHAKE) { return false; } @@ -447,12 +492,12 @@ static __always_inline bool is_tls_handshake_type(struct __sk_buff *skb, __u8 co } // is_tls_handshake_client_hello checks if the packet is a TLS ClientHello message -static __always_inline bool is_tls_handshake_client_hello(struct __sk_buff *skb, __u8 content_type, __u64 offset, __u32 data_end) { +static __always_inline bool is_tls_handshake_client_hello(struct __sk_buff *skb, __u8 content_type, __u32 offset, __u32 data_end) { return is_tls_handshake_type(skb, content_type, offset, data_end, TLS_HANDSHAKE_CLIENT_HELLO); } // is_tls_handshake_server_hello checks if the packet is a TLS ServerHello message -static __always_inline bool is_tls_handshake_server_hello(struct __sk_buff *skb, __u8 content_type, __u64 offset, __u32 data_end) { +static __always_inline bool is_tls_handshake_server_hello(struct __sk_buff *skb, __u8 content_type, __u32 offset, __u32 data_end) { return is_tls_handshake_type(skb, content_type, offset, data_end, TLS_HANDSHAKE_SERVER_HELLO); } diff --git a/pkg/network/encoding/marshal/format.go b/pkg/network/encoding/marshal/format.go index dc106447a1bc2c..2e0d40955191a7 100644 --- a/pkg/network/encoding/marshal/format.go +++ b/pkg/network/encoding/marshal/format.go @@ -13,7 +13,6 @@ import ( model "github.com/DataDog/agent-payload/v5/process" "github.com/DataDog/datadog-agent/pkg/network" - "github.com/DataDog/datadog-agent/pkg/network/protocols/tls" "github.com/DataDog/datadog-agent/pkg/process/util" ) @@ -121,7 +120,7 @@ func FormatConnection(builder *model.ConnectionBuilder, conn network.ConnectionS httpStaticTags, httpDynamicTags := httpEncoder.GetHTTPAggregationsAndTags(conn, builder) http2StaticTags, http2DynamicTags := http2Encoder.WriteHTTP2AggregationsAndTags(conn, builder) - tlsDynamicTags := tls.GetTLSDynamicTags(&conn.TLSTags) + tlsDynamicTags := conn.TLSTags.GetDynamicTags() staticTags := httpStaticTags | http2StaticTags dynamicTags := mergeDynamicTags(httpDynamicTags, http2DynamicTags, tlsDynamicTags) diff --git a/pkg/network/protocols/tls/types.go b/pkg/network/protocols/tls/types.go index 2bddcd5776bdeb..c3014c3f65f3e5 100644 --- a/pkg/network/protocols/tls/types.go +++ b/pkg/network/protocols/tls/types.go @@ -81,6 +81,9 @@ func (t *Tags) MergeWith(that Tags) { // IsEmpty returns true if all fields are zero func (t *Tags) IsEmpty() bool { + if t == nil { + return true + } return t.ChosenVersion == 0 && t.CipherSuite == 0 && t.OfferedVersions == 0 } @@ -91,7 +94,7 @@ func (t *Tags) String() string { // parseOfferedVersions parses the Offered_versions bitmask into a slice of version strings func parseOfferedVersions(offeredVersions uint8) []string { - versions := make([]string, 0, 4) + versions := make([]string, 0, len(offeredVersionBitmask)) for _, ov := range offeredVersionBitmask { if (offeredVersions & ov.bitMask) != 0 { if name := ClientVersionTags[ov.version]; name != "" { @@ -103,40 +106,29 @@ func parseOfferedVersions(offeredVersions uint8) []string { } func hexCipherSuiteTag(cipherSuite uint16) string { - // Preallocate a buffer for "0x" + 4 hex digits = 6 chars - var buf [6]byte - buf[0] = '0' - buf[1] = 'x' - hex := "0123456789ABCDEF" - - buf[2] = hex[(cipherSuite>>12)&0xF] - buf[3] = hex[(cipherSuite>>8)&0xF] - buf[4] = hex[(cipherSuite>>4)&0xF] - buf[5] = hex[cipherSuite&0xF] - - return TagTLSCipherSuiteID + string(buf[:]) + return fmt.Sprintf("%s0x%04X", TagTLSCipherSuiteID, cipherSuite) } -// GetTLSDynamicTags generates dynamic tags based on TLS information -func GetTLSDynamicTags(tls *Tags) map[string]struct{} { - if tls == nil { +// GetDynamicTags generates dynamic tags based on TLS information +func (t *Tags) GetDynamicTags() map[string]struct{} { + if t.IsEmpty() { return nil } tags := make(map[string]struct{}) // Server chosen version - if tag, ok := VersionTags[tls.ChosenVersion]; ok { + if tag, ok := VersionTags[t.ChosenVersion]; ok { tags[tag] = struct{}{} } // Client offered versions - for _, versionName := range parseOfferedVersions(tls.OfferedVersions) { + for _, versionName := range parseOfferedVersions(t.OfferedVersions) { tags[versionName] = struct{}{} } // Cipher suite ID as hex string - if tls.CipherSuite != 0 { - tags[hexCipherSuiteTag(tls.CipherSuite)] = struct{}{} + if t.CipherSuite != 0 { + tags[hexCipherSuiteTag(t.CipherSuite)] = struct{}{} } return tags diff --git a/pkg/network/protocols/tls/types_test.go b/pkg/network/protocols/tls/types_test.go index 359ff509d6b263..979cc2bbdba327 100644 --- a/pkg/network/protocols/tls/types_test.go +++ b/pkg/network/protocols/tls/types_test.go @@ -119,9 +119,9 @@ func TestGetTLSDynamicTags(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - result := GetTLSDynamicTags(test.tlsTags) + result := test.tlsTags.GetDynamicTags() if !reflect.DeepEqual(result, test.expected) { - t.Errorf("GetTLSDynamicTags(%v) = %v; want %v", test.tlsTags, result, test.expected) + t.Errorf("GetDynamicTags(%v) = %v; want %v", test.tlsTags, result, test.expected) } }) } diff --git a/pkg/network/tracer/tracer_linux_test.go b/pkg/network/tracer/tracer_linux_test.go index d9187fa5e166a9..2870ffc868720a 100644 --- a/pkg/network/tracer/tracer_linux_test.go +++ b/pkg/network/tracer/tracer_linux_test.go @@ -2705,9 +2705,10 @@ func (s *TracerSuite) TestTLSClassification() { if ebpftest.GetBuildMode() == ebpftest.Fentry { t.Skip("protocol classification not supported for fentry tracer") } - t.Cleanup(func() { tr.RemoveClient(clientID) }) - t.Cleanup(func() { _ = tr.Pause() }) - + t.Cleanup(func() { + tr.RemoveClient(clientID) + _ = tr.Pause() + }) tr.RemoveClient(clientID) require.NoError(t, tr.RegisterClient(clientID)) require.NoError(t, tr.Resume(), "enable probes - before post tracer") @@ -2722,7 +2723,7 @@ func validateTLSTags(t *testing.T, tr *Tracer, port uint16, scenario uint16) boo payload := getConnections(t, tr) for _, c := range payload.Conns { if c.DPort == port && c.ProtocolStack.Contains(protocols.TLS) && !c.TLSTags.IsEmpty() { - tlsTags := ddtls.GetTLSDynamicTags(&c.TLSTags) + tlsTags := c.TLSTags.GetDynamicTags() // Check that the cipher suite ID tag is present cipherSuiteTagFound := false From 1cefe66d8bad983bd3f02b716019e18dd31cfcc1 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Thu, 19 Dec 2024 16:51:15 -0500 Subject: [PATCH 50/53] numerous changes to docs and removal of change to usm_ctx --- .../ebpf/c/protocols/classification/defs.h | 8 +- .../classification/protocol-classification.h | 5 +- .../classification/routing-helpers.h | 29 +- .../ebpf/c/protocols/classification/routing.h | 3 +- .../classification/shared-tracer-maps.h | 30 -- .../c/protocols/classification/usm-context.h | 2 - pkg/network/ebpf/c/protocols/tls/tls.h | 359 +++++++++--------- pkg/network/ebpf/c/tracer/maps.h | 3 + pkg/network/ebpf/c/tracer/stats.h | 29 ++ pkg/network/ebpf/kprobe_types_linux.go | 10 +- 10 files changed, 227 insertions(+), 251 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/classification/defs.h b/pkg/network/ebpf/c/protocols/classification/defs.h index c91dc427d9e18c..235d638d2f5b59 100644 --- a/pkg/network/ebpf/c/protocols/classification/defs.h +++ b/pkg/network/ebpf/c/protocols/classification/defs.h @@ -114,10 +114,6 @@ typedef struct { typedef enum { CLASSIFICATION_PROG_UNKNOWN = 0, - __PROG_ENCRYPTION, - // Encryption classification programs go here - CLASSIFICATION_TLS_CLIENT_PROG, - CLASSIFICATION_TLS_SERVER_PROG, __PROG_APPLICATION, // Application classification programs go here CLASSIFICATION_QUEUES_PROG, @@ -125,6 +121,10 @@ typedef enum { __PROG_API, // API classification programs go here CLASSIFICATION_GRPC_PROG, + __PROG_ENCRYPTION, + // Encryption classification programs go here + CLASSIFICATION_TLS_CLIENT_PROG, + CLASSIFICATION_TLS_SERVER_PROG, CLASSIFICATION_PROG_MAX, } classification_prog_t; diff --git a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h index 145386c65d73df..4c62c543350ca6 100644 --- a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h +++ b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h @@ -177,7 +177,6 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct // Parse TLS handshake payload tls_info_t *tags = get_or_create_tls_enhanced_tags(&usm_ctx->tuple); if (tags) { - usm_ctx->tls_content_type = tls_hdr.content_type; // The packet is a TLS handshake, so trigger some tail calls // to extract metadata from the payload goto next_program; @@ -224,7 +223,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint_tls_ha } __u64 offset = usm_ctx->skb_info.data_off + sizeof(tls_record_header_t); __u32 data_end = usm_ctx->skb_info.data_end; - if (!is_tls_handshake_client_hello(skb, usm_ctx->tls_content_type, offset, usm_ctx->skb_info.data_end)) { + if (!is_tls_handshake_client_hello(skb, offset, usm_ctx->skb_info.data_end)) { goto next_program; } if (!parse_client_hello(skb, offset, data_end, tls_info)) { @@ -246,7 +245,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint_tls_ha } __u64 offset = usm_ctx->skb_info.data_off + sizeof(tls_record_header_t); __u32 data_end = usm_ctx->skb_info.data_end; - if (!is_tls_handshake_server_hello(skb, usm_ctx->tls_content_type, offset, data_end)) { + if (!is_tls_handshake_server_hello(skb, offset, data_end)) { goto next_program; } if (!parse_server_hello(skb, offset, data_end, tls_info)) { diff --git a/pkg/network/ebpf/c/protocols/classification/routing-helpers.h b/pkg/network/ebpf/c/protocols/classification/routing-helpers.h index 79dc2d79bb2745..944a8b0f049342 100644 --- a/pkg/network/ebpf/c/protocols/classification/routing-helpers.h +++ b/pkg/network/ebpf/c/protocols/classification/routing-helpers.h @@ -20,46 +20,31 @@ static __always_inline bool has_available_program(classification_prog_t current_ // get_current_program_layer returns the layer bit of the current program static __always_inline u16 get_current_program_layer(classification_prog_t current_program) { - if (current_program > __PROG_ENCRYPTION && current_program < __PROG_APPLICATION) { - return LAYER_ENCRYPTION_BIT; - } if (current_program > __PROG_APPLICATION && current_program < __PROG_API) { return LAYER_APPLICATION_BIT; } - if (current_program > __PROG_API && current_program < CLASSIFICATION_PROG_MAX) { + if (current_program > __PROG_API && current_program < __PROG_ENCRYPTION) { + return LAYER_ENCRYPTION_BIT; + } + if (current_program > __PROG_ENCRYPTION && current_program < CLASSIFICATION_PROG_MAX) { return LAYER_API_BIT; } return 0; } -// debug for if we don't reorder the programs -// static __always_inline u16 get_current_program_layer(classification_prog_t current_program) { -// if (current_program > __PROG_APPLICATION && current_program < __PROG_API) { -// return LAYER_APPLICATION_BIT; -// } -// if (current_program > __PROG_API && current_program < __PROG_ENCRYPTION) { -// return LAYER_ENCRYPTION_BIT; -// } -// if (current_program > __PROG_ENCRYPTION && current_program < CLASSIFICATION_PROG_MAX) { -// return LAYER_API_BIT; -// } - -// return 0; -// } - static __always_inline classification_prog_t next_layer_entrypoint(usm_context_t *usm_ctx) { u16 to_skip = usm_ctx->routing_skip_layers; + if (!(to_skip&LAYER_ENCRYPTION_BIT)) { + return __PROG_ENCRYPTION+1; + } if (!(to_skip&LAYER_APPLICATION_BIT)) { return __PROG_APPLICATION+1; } if (!(to_skip&LAYER_API_BIT)) { return __PROG_API+1; } - if (!(to_skip&LAYER_ENCRYPTION_BIT)) { - return __PROG_ENCRYPTION+1; - } return CLASSIFICATION_PROG_UNKNOWN; } diff --git a/pkg/network/ebpf/c/protocols/classification/routing.h b/pkg/network/ebpf/c/protocols/classification/routing.h index f801003f202bdc..9a4b534553b35c 100644 --- a/pkg/network/ebpf/c/protocols/classification/routing.h +++ b/pkg/network/ebpf/c/protocols/classification/routing.h @@ -61,8 +61,7 @@ static __always_inline void init_routing_cache(usm_context_t *usm_ctx, protocol_ } // We skip a given layer in two cases: - // 1) If the protocol for that layer is known, - // except for encryption as it still needs to be traversed for metadata + // 1) If the protocol for that layer is known // 2) If there are no programs registered for that layer if (stack->layer_application || !has_available_program(__PROG_APPLICATION)) { usm_ctx->routing_skip_layers |= LAYER_APPLICATION_BIT; diff --git a/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h b/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h index 5d67ffa013f777..7ca913c722f3a7 100644 --- a/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h +++ b/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h @@ -10,9 +10,6 @@ // classification procedures on the same connection BPF_HASH_MAP(connection_protocol, conn_tuple_t, protocol_stack_wrapper_t, 0) -// Map to store extra information about TLS connections like version, cipher, etc. -BPF_HASH_MAP(tls_enhanced_tags, conn_tuple_t, tls_info_wrapper_t, 1) - static __always_inline bool is_protocol_classification_supported() { __u64 val = 0; LOAD_CONSTANT("protocol_classification_enabled", val); @@ -149,33 +146,6 @@ __maybe_unused static __always_inline void delete_protocol_stack(conn_tuple_t* n bpf_map_delete_elem(&connection_protocol, normalized_tuple); } -static __always_inline tls_info_t* get_tls_enhanced_tags(conn_tuple_t* tuple) { - conn_tuple_t normalized_tup = *tuple; - normalize_tuple(&normalized_tup); - tls_info_wrapper_t *wrapper = bpf_map_lookup_elem(&tls_enhanced_tags, &normalized_tup); - if (!wrapper) { - return NULL; - } - wrapper->updated = bpf_ktime_get_ns(); - return &wrapper->info; -} -static __always_inline tls_info_t* get_or_create_tls_enhanced_tags(conn_tuple_t *tuple) { - tls_info_t *tags = get_tls_enhanced_tags(tuple); - if (!tags) { - conn_tuple_t normalized_tup = *tuple; - normalize_tuple(&normalized_tup); - tls_info_wrapper_t empty_tags_wrapper = {}; - empty_tags_wrapper.updated = bpf_ktime_get_ns(); - - bpf_map_update_with_telemetry(tls_enhanced_tags, &normalized_tup, &empty_tags_wrapper, BPF_ANY); - tls_info_wrapper_t *wrapper_ptr = bpf_map_lookup_elem(&tls_enhanced_tags, &normalized_tup); - if (!wrapper_ptr) { - return NULL; - } - tags = &wrapper_ptr->info; - } - return tags; -} #endif diff --git a/pkg/network/ebpf/c/protocols/classification/usm-context.h b/pkg/network/ebpf/c/protocols/classification/usm-context.h index 87027cf01d700a..47f637c1291cb8 100644 --- a/pkg/network/ebpf/c/protocols/classification/usm-context.h +++ b/pkg/network/ebpf/c/protocols/classification/usm-context.h @@ -15,7 +15,6 @@ typedef struct { size_t size; } classification_buffer_t; -// TODO: rename this struct to `classification_context_t` typedef struct { struct __sk_buff *owner; conn_tuple_t tuple; @@ -24,7 +23,6 @@ typedef struct { // bit mask with layers that should be skiped u16 routing_skip_layers; classification_prog_t routing_current_program; - __u8 tls_content_type; } usm_context_t; // Kernels before 4.7 do not know about per-cpu array maps. diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index b9e17ed208153a..a056794c78bb44 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -12,10 +12,10 @@ #define TLS_VERSION13 0x0304 // TLS Content Types (https://www.rfc-editor.org/rfc/rfc5246#page-19 6.2. Record Layer) -#define TLS_HANDSHAKE 0x16 -#define TLS_APPLICATION_DATA 0x17 +#define TLS_HANDSHAKE 0x16 +#define TLS_APPLICATION_DATA 0x17 #define TLS_CHANGE_CIPHER_SPEC 0x14 -#define TLS_ALERT 0x15 +#define TLS_ALERT 0x15 // TLS Handshake Types #define TLS_HANDSHAKE_CLIENT_HELLO 0x01 @@ -106,6 +106,8 @@ static __always_inline void set_tls_offered_version(tls_info_t *tls_info, __u16 } // read_tls_record_header reads the TLS record header from the packet +// Reference: RFC 5246 Section 6.2.1 (Record Layer), https://tools.ietf.org/html/rfc5246#section-6.2.1 +// Validates the record header fields (content_type, version, length) and checks for correctness within packet bounds. static __always_inline bool read_tls_record_header(struct __sk_buff *skb, __u64 header_offset, __u32 data_end, tls_record_header_t *tls_hdr) { // Ensure there's enough space for TLS record header if (header_offset + sizeof(tls_record_header_t) > data_end) { @@ -133,9 +135,11 @@ static __always_inline bool read_tls_record_header(struct __sk_buff *skb, __u64 return header_offset + sizeof(tls_record_header_t) + tls_hdr->length <= data_end; } -// is_tls checks if the packet is a TLS packet and reads the TLS record header -// Uses RFC 5246 Section 6.2.1 (https://www.rfc-editor.org/rfc/rfc5246#page-19) for record structure and content types +// is_tls checks if the packet is a TLS packet by reading and validating the TLS record header +// Reference: RFC 5246 Section 6.2.1 (Record Layer), https://tools.ietf.org/html/rfc5246#section-6.2.1 +// Validates that content_type matches known TLS types (Handshake, Application Data, etc.). static __always_inline bool is_tls(struct __sk_buff *skb, __u64 header_offset, __u32 data_end, tls_record_header_t *tls_hdr) { + // Read and validate the TLS record header if (!read_tls_record_header(skb, header_offset, data_end, tls_hdr)) { return false; } @@ -145,9 +149,10 @@ static __always_inline bool is_tls(struct __sk_buff *skb, __u64 header_offset, _ return ct == TLS_HANDSHAKE || ct == TLS_APPLICATION_DATA || ct == TLS_CHANGE_CIPHER_SPEC || ct == TLS_ALERT; } -// parse_tls_handshake_header extracts handshake_length and protocol_version from the handshake message. -// The handshake header (RFC 5246 Section 7.4, https://tools.ietf.org/html/rfc5246) starts with: -// handshake_type (1 byte), length (3 bytes), then protocol_version in case of Client/Server Hello. +// parse_tls_handshake_header extracts handshake_length and protocol_version from a TLS handshake message +// References: +// - RFC 5246 Section 7.4 (Handshake Protocol Overview), https://tools.ietf.org/html/rfc5246#section-7.4 +// For ClientHello and ServerHello, this includes parsing the handshake type (skipped prior) and the 3-byte length field, followed by a 2-byte protocol version field. static __always_inline bool parse_tls_handshake_header(struct __sk_buff *skb, __u32 *offset, __u32 data_end, __u32 *handshake_length, __u16 *protocol_version) { *offset += SINGLE_BYTE_LENGTH; // Move past handshake type (1 byte) @@ -183,6 +188,147 @@ static __always_inline bool parse_tls_handshake_header(struct __sk_buff *skb, __ return true; } +// skip_random_and_session_id Skips the Random (32 bytes) and the Session ID from the TLS Hello messages +// References: +// - RFC 5246 Section 7.4.1.2 (Client Hello and Server Hello): https://tools.ietf.org/html/rfc5246#section-7.4.1.2 +// ClientHello and ServerHello contain a "random" field (32 bytes) followed by a "session_id_length" (1 byte) +// and a session_id of that length. This helper increments the offset accordingly after reading and skipping these fields. +static __always_inline bool skip_random_and_session_id(struct __sk_buff *skb, __u32 *offset, __u32 data_end) { + // Skip Random (32 bytes) + *offset += RANDOM_LENGTH; + + // Read Session ID Length (1 byte) + if (*offset + SESSION_ID_LENGTH > data_end) { + return false; + } + __u8 session_id_length; + if (bpf_skb_load_bytes(skb, *offset, &session_id_length, SESSION_ID_LENGTH) < 0) { + return false; + } + *offset += SESSION_ID_LENGTH; + + // Skip Session ID + *offset += session_id_length; + + // Ensure we don't read beyond the packet + return *offset <= data_end; +} + +// parse_supported_versions_extension looks for the supported_versions extension in the ClientHello or ServerHello and populates tags +// References: +// - For TLS 1.3 supported_versions extension: RFC 8446 Section 4.2.1: https://tools.ietf.org/html/rfc8446#section-4.2.1 +// In the ClientHello, this extension contains a list of supported versions (2 bytes each) preceded by a 1-byte length. +// In the ServerHello (TLS 1.3), it contains a single selected_version (2 bytes). +static __always_inline bool parse_supported_versions_extension(struct __sk_buff *skb, __u32 *offset, __u32 data_end, __u64 extensions_end, tls_info_t *tags, bool is_client_hello) { + if (is_client_hello) { + // Read supported version list length (1 byte) + if (*offset + SINGLE_BYTE_LENGTH > data_end || *offset + SINGLE_BYTE_LENGTH > extensions_end) { + return false; + } + __u8 sv_list_length; + if (bpf_skb_load_bytes(skb, *offset, &sv_list_length, SINGLE_BYTE_LENGTH) < 0) { + return false; + } + *offset += SINGLE_BYTE_LENGTH; + + if (*offset + sv_list_length > data_end || *offset + sv_list_length > extensions_end) { + return false; + } + + // Parse the list of supported versions (2 bytes each) + __u8 sv_offset = 0; + __u16 sv_version; + #pragma unroll(MAX_SUPPORTED_VERSIONS) + for (int idx = 0; idx < MAX_SUPPORTED_VERSIONS; idx++) { + if (sv_offset + 1 >= sv_list_length) { + break; + } + // Each supported version is 2 bytes + if (*offset + PROTOCOL_VERSION_LENGTH > data_end) { + return false; + } + + if (bpf_skb_load_bytes(skb, *offset, &sv_version, PROTOCOL_VERSION_LENGTH) < 0) { + return false; + } + sv_version = bpf_ntohs(sv_version); + *offset += PROTOCOL_VERSION_LENGTH; + + set_tls_offered_version(tags, sv_version); + sv_offset += PROTOCOL_VERSION_LENGTH; + } + } else { + // ServerHello + // The selected_version field is 2 bytes + if (*offset + PROTOCOL_VERSION_LENGTH > data_end) { + return false; + } + + // Read selected version (2 bytes) + __u16 selected_version; + if (bpf_skb_load_bytes(skb, *offset, &selected_version, PROTOCOL_VERSION_LENGTH) < 0) { + return false; + } + selected_version = bpf_ntohs(selected_version); + *offset += PROTOCOL_VERSION_LENGTH; + + tags->chosen_version = selected_version; + } + + return true; +} + +// parse_tls_extensions parses TLS extensions in both ClientHello and ServerHello +// References: +// - RFC 5246 Section 7.4.1.4 (Hello Extensions): https://tools.ietf.org/html/rfc5246#section-7.4.1.4 +// - For TLS 1.3 supported_versions extension: RFC 8446 Section 4.2.1: https://tools.ietf.org/html/rfc8446#section-4.2.1 +// This function iterates over extensions, reading the extension_type and extension_length, and if it encounters +// the supported_versions extension, it calls parse_supported_versions_extension to handle it. +static __always_inline bool parse_tls_extensions(struct __sk_buff *skb, __u32 *offset, __u32 data_end, __u64 extensions_end, tls_info_t *tags, bool is_client_hello) { + __u16 extension_type; + __u16 extension_length; + + #pragma unroll(MAX_EXTENSIONS) + for (int i = 0; i < MAX_EXTENSIONS; i++) { + if (*offset + MIN_EXTENSION_HEADER_LENGTH > extensions_end) { + break; + } + + // Read Extension Type (2 bytes) + if (bpf_skb_load_bytes(skb, *offset, &extension_type, EXTENSION_TYPE_LENGTH) < 0) { + return false; + } + extension_type = bpf_ntohs(extension_type); + *offset += EXTENSION_TYPE_LENGTH; + + // Read Extension Length (2 bytes) + if (bpf_skb_load_bytes(skb, *offset, &extension_length, EXTENSION_LENGTH_FIELD) < 0) { + return false; + } + extension_length = bpf_ntohs(extension_length); + *offset += EXTENSION_LENGTH_FIELD; + + if (*offset + extension_length > data_end || *offset + extension_length > extensions_end) { + return false; + } + + if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { + if (!parse_supported_versions_extension(skb, offset, data_end, extensions_end, tags, is_client_hello)) { + return false; + } + } else { + // Skip other extensions + *offset += extension_length; + } + + if (*offset >= extensions_end) { + break; + } + } + + return true; +} + // parse_client_hello parses the ClientHello message and populates tags // Reference: RFC 5246 Section 7.4.1.2 (Client Hello), https://tools.ietf.org/html/rfc5246 // Structure (simplified): @@ -197,43 +343,17 @@ static __always_inline bool parse_client_hello(struct __sk_buff *skb, __u32 offs set_tls_offered_version(tags, client_version); - // If client_version < TLS 1.2, no extensions to parse + // TLS 1.2 is the highest version we will see in the header. If the connection is actually a higher version (1.3), + // it must be extracted from the extensions. Lower versions (1.0, 1.1) will not have extensions. if (client_version != TLS_VERSION12) { - // Skip Random (32 bytes) - offset += RANDOM_LENGTH; - - // Session ID Length (1 byte) - if (offset + SESSION_ID_LENGTH > data_end) { - return false; - } - __u8 session_id_length; - if (bpf_skb_load_bytes(skb, offset, &session_id_length, SESSION_ID_LENGTH) < 0) { - return false; - } - offset += SESSION_ID_LENGTH; - - // Skip Session ID - offset += session_id_length; return true; } - // TLS 1.2 case: - // Skip Random (32 bytes) - offset += RANDOM_LENGTH; - - // Session ID Length (1 byte) - if (offset + SESSION_ID_LENGTH > data_end) { - return false; - } - __u8 session_id_length; - if (bpf_skb_load_bytes(skb, offset, &session_id_length, SESSION_ID_LENGTH) < 0) { + if (!skip_random_and_session_id(skb, &offset, data_end)) { return false; } - offset += SESSION_ID_LENGTH; - - // Skip Session ID - offset += session_id_length; + // Read Cipher Suites Length (2 bytes) if (offset + CIPHER_SUITES_LENGTH > data_end) { return false; } @@ -247,7 +367,7 @@ static __always_inline bool parse_client_hello(struct __sk_buff *skb, __u32 offs // Skip Cipher Suites offset += cipher_suites_length; - // Compression Methods Length (1 byte) + // Read Compression Methods Length (1 byte) if (offset + COMPRESSION_METHODS_LENGTH > data_end) { return false; } @@ -260,10 +380,12 @@ static __always_inline bool parse_client_hello(struct __sk_buff *skb, __u32 offs // Skip Compression Methods offset += compression_methods_length; - // Extensions Length (2 bytes) + // Check if extensions are present if (offset + EXTENSION_LENGTH_FIELD > data_end) { return false; } + + // Read Extensions Length (2 bytes) __u16 extensions_length; if (bpf_skb_load_bytes(skb, offset, &extensions_length, EXTENSION_LENGTH_FIELD) < 0) { return false; @@ -277,78 +399,7 @@ static __always_inline bool parse_client_hello(struct __sk_buff *skb, __u32 offs __u64 extensions_end = offset + extensions_length; - // Inline extension parsing: - __u16 extension_type; - __u16 extension_length; - #pragma unroll(MAX_EXTENSIONS) - for (int i = 0; i < MAX_EXTENSIONS; i++) { - if (offset + MIN_EXTENSION_HEADER_LENGTH > extensions_end) { - break; - } - - // Read Extension Type (2 bytes) - if (bpf_skb_load_bytes(skb, offset, &extension_type, EXTENSION_TYPE_LENGTH) < 0) { - return false; - } - extension_type = bpf_ntohs(extension_type); - offset += EXTENSION_TYPE_LENGTH; - - // Read Extension Length (2 bytes) - if (bpf_skb_load_bytes(skb, offset, &extension_length, EXTENSION_LENGTH_FIELD) < 0) { - return false; - } - extension_length = bpf_ntohs(extension_length); - offset += EXTENSION_LENGTH_FIELD; - - if (offset + extension_length > data_end || offset + extension_length > extensions_end) { - return false; - } - - if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { - if (offset + SINGLE_BYTE_LENGTH > data_end || offset + SINGLE_BYTE_LENGTH > extensions_end) { - return false; - } - __u8 sv_list_length; - if (bpf_skb_load_bytes(skb, offset, &sv_list_length, SINGLE_BYTE_LENGTH) < 0) { - return false; - } - offset += SINGLE_BYTE_LENGTH; - - if (offset + sv_list_length > data_end || offset + sv_list_length > extensions_end) { - return false; - } - - __u8 sv_offset = 0; - __u16 sv_version; - #pragma unroll(MAX_SUPPORTED_VERSIONS) - for (int j = 0; j < MAX_SUPPORTED_VERSIONS; j++) { - if (sv_offset + 1 >= sv_list_length) { - break; - } - if (offset + PROTOCOL_VERSION_LENGTH > data_end) { - return false; - } - - if (bpf_skb_load_bytes(skb, offset, &sv_version, PROTOCOL_VERSION_LENGTH) < 0) { - return false; - } - sv_version = bpf_ntohs(sv_version); - offset += PROTOCOL_VERSION_LENGTH; - - set_tls_offered_version(tags, sv_version); - sv_offset += PROTOCOL_VERSION_LENGTH; - } - } else { - // Skip other extensions - offset += extension_length; - } - - if (offset >= extensions_end) { - break; - } - } - - return true; + return parse_tls_extensions(skb, &offset, data_end, extensions_end, tags, true); } // parse_server_hello parses the ServerHello message and populates tags @@ -368,17 +419,9 @@ static __always_inline bool parse_server_hello(struct __sk_buff *skb, __u32 offs // The actual version is embedded in the supported_versions extension tags->chosen_version = server_version; - offset += RANDOM_LENGTH; // Skip Random - - if (offset + SESSION_ID_LENGTH > data_end) { + if (!skip_random_and_session_id(skb, &offset, data_end)) { return false; } - __u8 session_id_length; - if (bpf_skb_load_bytes(skb, offset, &session_id_length, SESSION_ID_LENGTH) < 0) { - return false; - } - offset += SESSION_ID_LENGTH; - offset += session_id_length; // Skip Session ID // Read Cipher Suite (2 bytes) if (offset + CIPHER_SUITES_LENGTH > data_end) { @@ -391,7 +434,8 @@ static __always_inline bool parse_server_hello(struct __sk_buff *skb, __u32 offs cipher_suite = bpf_ntohs(cipher_suite); offset += CIPHER_SUITES_LENGTH; - offset += COMPRESSION_METHODS_LENGTH; // Skip Compression Method + // Skip Compression Method (1 byte) + offset += COMPRESSION_METHODS_LENGTH; tags->cipher_suite = cipher_suite; @@ -420,65 +464,14 @@ static __always_inline bool parse_server_hello(struct __sk_buff *skb, __u32 offs __u64 extensions_end = offset + extensions_length; - __u16 extension_type; - __u16 extension_length; - #pragma unroll(MAX_EXTENSIONS) - for (int i = 0; i < MAX_EXTENSIONS; i++) { - if (offset + MIN_EXTENSION_HEADER_LENGTH > extensions_end) { - break; - } - - // Read Extension Type (2 bytes) - if (bpf_skb_load_bytes(skb, offset, &extension_type, EXTENSION_TYPE_LENGTH) < 0) { - return false; - } - extension_type = bpf_ntohs(extension_type); - offset += EXTENSION_TYPE_LENGTH; - - // Read Extension Length (2 bytes) - if (bpf_skb_load_bytes(skb, offset, &extension_length, EXTENSION_LENGTH_FIELD) < 0) { - return false; - } - extension_length = bpf_ntohs(extension_length); - offset += EXTENSION_LENGTH_FIELD; - - if (offset + extension_length > data_end || offset + extension_length > extensions_end) { - return false; - } - - if (extension_type == SUPPORTED_VERSIONS_EXTENSION) { - // Inline parse_supported_versions_extension for ServerHello - if (offset + PROTOCOL_VERSION_LENGTH > data_end) { - return false; - } - - __u16 selected_version; - if (bpf_skb_load_bytes(skb, offset, &selected_version, PROTOCOL_VERSION_LENGTH) < 0) { - return false; - } - selected_version = bpf_ntohs(selected_version); - offset += PROTOCOL_VERSION_LENGTH; - - tags->chosen_version = selected_version; - } else { - // Skip other extensions - offset += extension_length; - } - - if (offset >= extensions_end) { - break; - } - } - - return true; + return parse_tls_extensions(skb, &offset, data_end, extensions_end, tags, false); } -// is_tls_handshake_type checks if the handshake type is the expected type (client or server hello) -static __always_inline bool is_tls_handshake_type(struct __sk_buff *skb, __u8 content_type, __u32 offset, __u32 data_end, __u8 expected_handshake_type) { - if (content_type != TLS_HANDSHAKE) { - return false; - } - +// is_tls_handshake_type checks if the handshake type at the given offset matches the expected type (e.g., ClientHello or ServerHello) +// References: +// - RFC 5246 Section 7.4 (Handshake Protocol Overview), https://tools.ietf.org/html/rfc5246#section-7.4 +// The handshake_type is a single byte enumerated value. +static __always_inline bool is_tls_handshake_type(struct __sk_buff *skb, __u32 offset, __u32 data_end, __u8 expected_handshake_type) { // The handshake type is a single byte enumerated value if (offset + SINGLE_BYTE_LENGTH > data_end) { return false; @@ -492,13 +485,13 @@ static __always_inline bool is_tls_handshake_type(struct __sk_buff *skb, __u8 co } // is_tls_handshake_client_hello checks if the packet is a TLS ClientHello message -static __always_inline bool is_tls_handshake_client_hello(struct __sk_buff *skb, __u8 content_type, __u32 offset, __u32 data_end) { - return is_tls_handshake_type(skb, content_type, offset, data_end, TLS_HANDSHAKE_CLIENT_HELLO); +static __always_inline bool is_tls_handshake_client_hello(struct __sk_buff *skb, __u32 offset, __u32 data_end) { + return is_tls_handshake_type(skb, offset, data_end, TLS_HANDSHAKE_CLIENT_HELLO); } // is_tls_handshake_server_hello checks if the packet is a TLS ServerHello message -static __always_inline bool is_tls_handshake_server_hello(struct __sk_buff *skb, __u8 content_type, __u32 offset, __u32 data_end) { - return is_tls_handshake_type(skb, content_type, offset, data_end, TLS_HANDSHAKE_SERVER_HELLO); +static __always_inline bool is_tls_handshake_server_hello(struct __sk_buff *skb, __u32 offset, __u32 data_end) { + return is_tls_handshake_type(skb, offset, data_end, TLS_HANDSHAKE_SERVER_HELLO); } #endif // __TLS_H diff --git a/pkg/network/ebpf/c/tracer/maps.h b/pkg/network/ebpf/c/tracer/maps.h index e6123782f8ea56..0141a970d1e5eb 100644 --- a/pkg/network/ebpf/c/tracer/maps.h +++ b/pkg/network/ebpf/c/tracer/maps.h @@ -132,4 +132,7 @@ BPF_HASH_MAP(tcp_close_args, __u64, conn_tuple_t, 1024) // by using tail call. BPF_PROG_ARRAY(tcp_close_progs, 1) +// Map to store extra information about TLS connections like version, cipher, etc. +BPF_HASH_MAP(tls_enhanced_tags, conn_tuple_t, tls_info_wrapper_t, 0) + #endif diff --git a/pkg/network/ebpf/c/tracer/stats.h b/pkg/network/ebpf/c/tracer/stats.h index 5698d88b91f763..ce47839bba79ea 100644 --- a/pkg/network/ebpf/c/tracer/stats.h +++ b/pkg/network/ebpf/c/tracer/stats.h @@ -22,6 +22,35 @@ static __always_inline __u64 offset_rtt(); static __always_inline __u64 offset_rtt_var(); #endif +static __always_inline tls_info_t* get_tls_enhanced_tags(conn_tuple_t* tuple) { + conn_tuple_t normalized_tup = *tuple; + normalize_tuple(&normalized_tup); + tls_info_wrapper_t *wrapper = bpf_map_lookup_elem(&tls_enhanced_tags, &normalized_tup); + if (!wrapper) { + return NULL; + } + wrapper->updated = bpf_ktime_get_ns(); + return &wrapper->info; +} + +static __always_inline tls_info_t* get_or_create_tls_enhanced_tags(conn_tuple_t *tuple) { + tls_info_t *tags = get_tls_enhanced_tags(tuple); + if (!tags) { + conn_tuple_t normalized_tup = *tuple; + normalize_tuple(&normalized_tup); + tls_info_wrapper_t empty_tags_wrapper = {}; + empty_tags_wrapper.updated = bpf_ktime_get_ns(); + + bpf_map_update_with_telemetry(tls_enhanced_tags, &normalized_tup, &empty_tags_wrapper, BPF_ANY); + tls_info_wrapper_t *wrapper_ptr = bpf_map_lookup_elem(&tls_enhanced_tags, &normalized_tup); + if (!wrapper_ptr) { + return NULL; + } + tags = &wrapper_ptr->info; + } + return tags; +} + // merge_tls_info modifies `this` by merging it with `that` static __always_inline void merge_tls_info(tls_info_t *this, tls_info_t *that) { if (!this || !that) { diff --git a/pkg/network/ebpf/kprobe_types_linux.go b/pkg/network/ebpf/kprobe_types_linux.go index 5e413cb1964f19..2953275bd5bfb0 100644 --- a/pkg/network/ebpf/kprobe_types_linux.go +++ b/pkg/network/ebpf/kprobe_types_linux.go @@ -145,9 +145,9 @@ const SizeofConn = 0x78 type ClassificationProgram = uint32 const ( - ClassificationTLSClient ClassificationProgram = 0x2 - ClassificationTLSServer ClassificationProgram = 0x3 - ClassificationQueues ClassificationProgram = 0x5 - ClassificationDBs ClassificationProgram = 0x6 - ClassificationGRPC ClassificationProgram = 0x8 + ClassificationTLSClient ClassificationProgram = 0x7 + ClassificationTLSServer ClassificationProgram = 0x8 + ClassificationQueues ClassificationProgram = 0x2 + ClassificationDBs ClassificationProgram = 0x3 + ClassificationGRPC ClassificationProgram = 0x5 ) From c52e48e4968bab5c0b0e9df2e15522ec396808ab Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Thu, 19 Dec 2024 17:44:19 -0500 Subject: [PATCH 51/53] fix typos --- .../c/protocols/classification/protocol-classification.h | 7 +++---- .../ebpf/c/protocols/classification/routing-helpers.h | 6 ++++-- .../ebpf/c/protocols/classification/shared-tracer-maps.h | 3 --- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h index 4c62c543350ca6..865288731ac7f4 100644 --- a/pkg/network/ebpf/c/protocols/classification/protocol-classification.h +++ b/pkg/network/ebpf/c/protocols/classification/protocol-classification.h @@ -177,8 +177,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint(struct // Parse TLS handshake payload tls_info_t *tags = get_or_create_tls_enhanced_tags(&usm_ctx->tuple); if (tags) { - // The packet is a TLS handshake, so trigger some tail calls - // to extract metadata from the payload + // The packet is a TLS handshake, so trigger tail calls to extract metadata from the payload goto next_program; } return; @@ -221,7 +220,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint_tls_ha if (!tls_info) { goto next_program; } - __u64 offset = usm_ctx->skb_info.data_off + sizeof(tls_record_header_t); + __u32 offset = usm_ctx->skb_info.data_off + sizeof(tls_record_header_t); __u32 data_end = usm_ctx->skb_info.data_end; if (!is_tls_handshake_client_hello(skb, offset, usm_ctx->skb_info.data_end)) { goto next_program; @@ -243,7 +242,7 @@ __maybe_unused static __always_inline void protocol_classifier_entrypoint_tls_ha if (!tls_info) { goto next_program; } - __u64 offset = usm_ctx->skb_info.data_off + sizeof(tls_record_header_t); + __u32 offset = usm_ctx->skb_info.data_off + sizeof(tls_record_header_t); __u32 data_end = usm_ctx->skb_info.data_end; if (!is_tls_handshake_server_hello(skb, offset, data_end)) { goto next_program; diff --git a/pkg/network/ebpf/c/protocols/classification/routing-helpers.h b/pkg/network/ebpf/c/protocols/classification/routing-helpers.h index 944a8b0f049342..d7565977eb36d5 100644 --- a/pkg/network/ebpf/c/protocols/classification/routing-helpers.h +++ b/pkg/network/ebpf/c/protocols/classification/routing-helpers.h @@ -23,11 +23,13 @@ static __always_inline u16 get_current_program_layer(classification_prog_t curre if (current_program > __PROG_APPLICATION && current_program < __PROG_API) { return LAYER_APPLICATION_BIT; } + if (current_program > __PROG_API && current_program < __PROG_ENCRYPTION) { - return LAYER_ENCRYPTION_BIT; + return LAYER_API_BIT; } + if (current_program > __PROG_ENCRYPTION && current_program < CLASSIFICATION_PROG_MAX) { - return LAYER_API_BIT; + return LAYER_ENCRYPTION_BIT; } return 0; diff --git a/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h b/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h index 7ca913c722f3a7..d3e3c4c73a558f 100644 --- a/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h +++ b/pkg/network/ebpf/c/protocols/classification/shared-tracer-maps.h @@ -4,7 +4,6 @@ #include "map-defs.h" #include "port_range.h" #include "protocols/classification/stack-helpers.h" -#include "protocols/tls/tls.h" // Maps a connection tuple to its classified protocol. Used to reduce redundant // classification procedures on the same connection @@ -146,6 +145,4 @@ __maybe_unused static __always_inline void delete_protocol_stack(conn_tuple_t* n bpf_map_delete_elem(&connection_protocol, normalized_tuple); } - - #endif From fd3800b7bcfff9d218aa6b50003e8f59b4d47bf9 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Fri, 20 Dec 2024 13:01:03 -0500 Subject: [PATCH 52/53] even stricter validation of tls handshake classification, ascii art for docs --- pkg/network/ebpf/c/protocols/tls/tls.h | 151 ++++++++++++++++++++++--- 1 file changed, 138 insertions(+), 13 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index a056794c78bb44..c1587d5e8d7da2 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -45,14 +45,15 @@ // RFC 8446 - The Transport Layer Security (TLS) Protocol Version 1.3 // https://tools.ietf.org/html/rfc8446 // Many handshake structures are similar, but some extensions (like supported_versions) are defined here -#define TLS_HANDSHAKE_LENGTH 3 // Handshake length is 3 bytes (RFC 5246 Section 7.4) -#define RANDOM_LENGTH 32 // Random field length in Client/Server Hello (RFC 5246 Section 7.4.1.2) -#define PROTOCOL_VERSION_LENGTH 2 // Protocol version field is 2 bytes (RFC 5246 Section 6.2.1) -#define SESSION_ID_LENGTH 1 // Session ID length field is 1 byte (RFC 5246 Section 7.4.1.2) -#define CIPHER_SUITES_LENGTH 2 // Cipher Suites length field is 2 bytes (RFC 5246 Section 7.4.1.2) -#define COMPRESSION_METHODS_LENGTH 1 // Compression Methods length field is 1 byte (RFC 5246 Section 7.4.1.2) -#define EXTENSION_TYPE_LENGTH 2 // Extension Type field is 2 bytes (RFC 5246 Section 7.4.1.4) -#define EXTENSION_LENGTH_FIELD 2 // Extension Length field is 2 bytes (RFC 5246 Section 7.4.1.4) +#define TLS_HANDSHAKE_LENGTH 3 // Handshake length is 3 bytes (RFC 5246 Section 7.4) +#define TLS_HELLO_MESSAGE_HEADER_SIZE 4 // handshake_type(1) + length(3) +#define RANDOM_LENGTH 32 // Random field length in Client/Server Hello (RFC 5246 Section 7.4.1.2) +#define PROTOCOL_VERSION_LENGTH 2 // Protocol version field is 2 bytes (RFC 5246 Section 6.2.1) +#define SESSION_ID_LENGTH 1 // Session ID length field is 1 byte (RFC 5246 Section 7.4.1.2) +#define CIPHER_SUITES_LENGTH 2 // Cipher Suites length field is 2 bytes (RFC 5246 Section 7.4.1.2) +#define COMPRESSION_METHODS_LENGTH 1 // Compression Methods length field is 1 byte (RFC 5246 Section 7.4.1.2) +#define EXTENSION_TYPE_LENGTH 2 // Extension Type field is 2 bytes (RFC 5246 Section 7.4.1.4) +#define EXTENSION_LENGTH_FIELD 2 // Extension Length field is 2 bytes (RFC 5246 Section 7.4.1.4) // For single-byte fields (list lengths, etc.) #define SINGLE_BYTE_LENGTH 1 @@ -105,6 +106,15 @@ static __always_inline void set_tls_offered_version(tls_info_t *tls_info, __u16 } } +// TLS Record Header (RFC 5246 Section 6.2.1) +// +// +---------+---------+---------+-----------+ +// | type(1) | version(2) | length(2) | +// +---------+---------+---------+-----------+ +// type: 1 byte (TLS_CONTENT_TYPE) +// version: 2 bytes (e.g., 0x03 0x03 for TLS 1.2) +// length: 2 bytes (total number of payload bytes following this header) + // read_tls_record_header reads the TLS record header from the packet // Reference: RFC 5246 Section 6.2.1 (Record Layer), https://tools.ietf.org/html/rfc5246#section-6.2.1 // Validates the record header fields (content_type, version, length) and checks for correctness within packet bounds. @@ -135,6 +145,68 @@ static __always_inline bool read_tls_record_header(struct __sk_buff *skb, __u64 return header_offset + sizeof(tls_record_header_t) + tls_hdr->length <= data_end; } +// TLS Handshake Message Header (RFC 5246 Section 7.4) +// +---------+---------+---------+---------+ +// | handshake_type(1) | length(3 bytes) | +// +---------+---------+---------+---------+ +// +// The handshake_type identifies the handshake message (e.g., ClientHello, ServerHello). +// length indicates the size of the handshake message that follows (not including these 4 bytes). + +// is_valid_tls_handshake checks if the TLS handshake message is valid +// The function expects the record to have already been validated. It further checks that the +// handshake_type and handshake_length are consistent. +static __always_inline bool is_valid_tls_handshake(struct __sk_buff *skb, __u64 header_offset, __u32 data_end, const tls_record_header_t *hdr) { + // At this point, we know from read_tls_record_header() that: + // - hdr->version is a valid TLS version + // - hdr->length fits entirely within the packet (header_offset + hdr->length <= data_end) + + __u64 handshake_offset = header_offset + sizeof(tls_record_header_t); + + // Ensure we don't read beyond the packet + if (handshake_offset + SINGLE_BYTE_LENGTH > data_end) { + return false; + } + // Read handshake_type (1 byte) + __u8 handshake_type; + if (bpf_skb_load_bytes(skb, handshake_offset, &handshake_type, SINGLE_BYTE_LENGTH) < 0) { + return false; + } + + // Read handshake_length (3 bytes) + __u64 length_offset = handshake_offset + SINGLE_BYTE_LENGTH; + if (length_offset + TLS_HANDSHAKE_LENGTH > data_end) { + return false; + } + __u8 handshake_length_bytes[TLS_HANDSHAKE_LENGTH]; + if (bpf_skb_load_bytes(skb, length_offset, handshake_length_bytes, TLS_HANDSHAKE_LENGTH) < 0) { + return false; + } + + __u32 handshake_length = (handshake_length_bytes[0] << 16) | + (handshake_length_bytes[1] << 8) | + handshake_length_bytes[2]; + + // Verify that the handshake message length plus the 4-byte handshake header (1 byte type + 3 bytes length) + // matches the total length defined in the record header. + // If handshake_length + TLS_HELLO_MESSAGE_HEADER_SIZE != hdr->length, the handshake message structure is inconsistent. + if (handshake_length + TLS_HELLO_MESSAGE_HEADER_SIZE != hdr->length) { + return false; + } + + // Check that the handshake_type is one of the expected values (ClientHello or ServerHello). + // This ensures we are dealing with a known handshake message type. + if (handshake_type != TLS_HANDSHAKE_CLIENT_HELLO && handshake_type != TLS_HANDSHAKE_SERVER_HELLO) { + return false; + } + + // At this point, we've confirmed: + // - The handshake message fits within the record. + // - The handshake_type is a known TLS Hello message. + // - The handshake_length matches the record header's length. + return true; +} + // is_tls checks if the packet is a TLS packet by reading and validating the TLS record header // Reference: RFC 5246 Section 6.2.1 (Record Layer), https://tools.ietf.org/html/rfc5246#section-6.2.1 // Validates that content_type matches known TLS types (Handshake, Application Data, etc.). @@ -144,9 +216,16 @@ static __always_inline bool is_tls(struct __sk_buff *skb, __u64 header_offset, _ return false; } - // Validate content type - __u8 ct = tls_hdr->content_type; - return ct == TLS_HANDSHAKE || ct == TLS_APPLICATION_DATA || ct == TLS_CHANGE_CIPHER_SPEC || ct == TLS_ALERT; + switch (tls_hdr->content_type) { + case TLS_HANDSHAKE: + return is_valid_tls_handshake(skb, header_offset, data_end, tls_hdr); + case TLS_APPLICATION_DATA: + case TLS_CHANGE_CIPHER_SPEC: + case TLS_ALERT: + return true; + default: + return false; + } } // parse_tls_handshake_header extracts handshake_length and protocol_version from a TLS handshake message @@ -217,8 +296,15 @@ static __always_inline bool skip_random_and_session_id(struct __sk_buff *skb, __ // parse_supported_versions_extension looks for the supported_versions extension in the ClientHello or ServerHello and populates tags // References: // - For TLS 1.3 supported_versions extension: RFC 8446 Section 4.2.1: https://tools.ietf.org/html/rfc8446#section-4.2.1 -// In the ClientHello, this extension contains a list of supported versions (2 bytes each) preceded by a 1-byte length. -// In the ServerHello (TLS 1.3), it contains a single selected_version (2 bytes). +// For ClientHello this extension contains a list of supported versions (2 bytes each) preceded by a 1-byte length. +// supported_versions extension structure: +// +-----+--------------------+ +// | len(1) | versions(2 * N) | +// +-----+--------------------+ +// For ServerHello (TLS 1.3), it contains a single selected_version (2 bytes). +// +---------------------+ +// | selected_version(2) | +// +---------------------+ static __always_inline bool parse_supported_versions_extension(struct __sk_buff *skb, __u32 *offset, __u32 data_end, __u64 extensions_end, tls_info_t *tags, bool is_client_hello) { if (is_client_hello) { // Read supported version list length (1 byte) @@ -284,6 +370,11 @@ static __always_inline bool parse_supported_versions_extension(struct __sk_buff // - For TLS 1.3 supported_versions extension: RFC 8446 Section 4.2.1: https://tools.ietf.org/html/rfc8446#section-4.2.1 // This function iterates over extensions, reading the extension_type and extension_length, and if it encounters // the supported_versions extension, it calls parse_supported_versions_extension to handle it. +// ASCII snippet for a single extension: +// +---------+---------+--------------------------------+ +// | ext_type(2) | ext_length(2) | ext_data(ext_length) | +// +---------+---------+--------------------------------+ +// For multiple extensions, they are just concatenated one after another. static __always_inline bool parse_tls_extensions(struct __sk_buff *skb, __u32 *offset, __u32 data_end, __u64 extensions_end, tls_info_t *tags, bool is_client_hello) { __u16 extension_type; __u16 extension_length; @@ -333,6 +424,24 @@ static __always_inline bool parse_tls_extensions(struct __sk_buff *skb, __u32 *o // Reference: RFC 5246 Section 7.4.1.2 (Client Hello), https://tools.ietf.org/html/rfc5246 // Structure (simplified): // handshake_type (1 byte), length (3 bytes), version (2 bytes), random(32 bytes), session_id_length(1 byte), session_id(variable), cipher_suites_length(2 bytes), cipher_suites(variable), compression_methods_length(1 byte), compression_methods(variable), extensions_length(2 bytes), extensions(variable) +// After the handshake header (handshake_type + length), the ClientHello fields are: +// +----------------------------+ +// | client_version (2) | +// +----------------------------+ +// | random (32) | +// +----------------------------+ +// | session_id_length (1) | +// | session_id (...) | +// +----------------------------+ +// | cipher_suites_length(2) | +// | cipher_suites(...) | +// +----------------------------+ +// | compression_methods_len(1) | +// | compression_methods(...) | +// +----------------------------+ +// | extensions_length (2) | +// | extensions(...) | +// +----------------------------+ static __always_inline bool parse_client_hello(struct __sk_buff *skb, __u32 offset, __u32 data_end, tls_info_t *tags) { __u32 handshake_length; __u16 client_version; @@ -406,6 +515,22 @@ static __always_inline bool parse_client_hello(struct __sk_buff *skb, __u32 offs // Reference: RFC 5246 Section 7.4.1.2 (Server Hello), https://tools.ietf.org/html/rfc5246 // Structure (simplified): // handshake_type(1), length(3), version(2), random(32), session_id_length(1), session_id(variable), cipher_suite(2), compression_method(1), extensions_length(2), extensions(variable) +// After the handshake header (handshake_type + length), the ServerHello fields are: +// +------------------------+ +// | server_version (2) | +// +------------------------+ +// | random (32) | +// +------------------------+ +// | session_id_length (1) | +// | session_id (...) | +// +------------------------+ +// | cipher_suite (2) | +// +------------------------+ +// | compression_method (1) | +// +------------------------+ +// | extensions_length(2) | +// | extensions(...) | +// +------------------------+ static __always_inline bool parse_server_hello(struct __sk_buff *skb, __u32 offset, __u32 data_end, tls_info_t *tags) { __u32 handshake_length; __u16 server_version; From e54754254559428cb4cd9bbf8f9e76d6ccdd2564 Mon Sep 17 00:00:00 2001 From: Adam Karpowich Date: Mon, 23 Dec 2024 10:15:25 -0500 Subject: [PATCH 53/53] use u32s in tls parsing, fix tests --- pkg/network/ebpf/c/protocols/tls/tls.h | 20 ++++++++++---------- pkg/network/tracer/tracer_linux_test.go | 18 ++++++------------ 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/pkg/network/ebpf/c/protocols/tls/tls.h b/pkg/network/ebpf/c/protocols/tls/tls.h index c1587d5e8d7da2..813c548360a9af 100644 --- a/pkg/network/ebpf/c/protocols/tls/tls.h +++ b/pkg/network/ebpf/c/protocols/tls/tls.h @@ -118,7 +118,7 @@ static __always_inline void set_tls_offered_version(tls_info_t *tls_info, __u16 // read_tls_record_header reads the TLS record header from the packet // Reference: RFC 5246 Section 6.2.1 (Record Layer), https://tools.ietf.org/html/rfc5246#section-6.2.1 // Validates the record header fields (content_type, version, length) and checks for correctness within packet bounds. -static __always_inline bool read_tls_record_header(struct __sk_buff *skb, __u64 header_offset, __u32 data_end, tls_record_header_t *tls_hdr) { +static __always_inline bool read_tls_record_header(struct __sk_buff *skb, __u32 header_offset, __u32 data_end, tls_record_header_t *tls_hdr) { // Ensure there's enough space for TLS record header if (header_offset + sizeof(tls_record_header_t) > data_end) { return false; @@ -156,12 +156,12 @@ static __always_inline bool read_tls_record_header(struct __sk_buff *skb, __u64 // is_valid_tls_handshake checks if the TLS handshake message is valid // The function expects the record to have already been validated. It further checks that the // handshake_type and handshake_length are consistent. -static __always_inline bool is_valid_tls_handshake(struct __sk_buff *skb, __u64 header_offset, __u32 data_end, const tls_record_header_t *hdr) { +static __always_inline bool is_valid_tls_handshake(struct __sk_buff *skb, __u32 header_offset, __u32 data_end, const tls_record_header_t *hdr) { // At this point, we know from read_tls_record_header() that: // - hdr->version is a valid TLS version // - hdr->length fits entirely within the packet (header_offset + hdr->length <= data_end) - __u64 handshake_offset = header_offset + sizeof(tls_record_header_t); + __u32 handshake_offset = header_offset + sizeof(tls_record_header_t); // Ensure we don't read beyond the packet if (handshake_offset + SINGLE_BYTE_LENGTH > data_end) { @@ -174,7 +174,7 @@ static __always_inline bool is_valid_tls_handshake(struct __sk_buff *skb, __u64 } // Read handshake_length (3 bytes) - __u64 length_offset = handshake_offset + SINGLE_BYTE_LENGTH; + __u32 length_offset = handshake_offset + SINGLE_BYTE_LENGTH; if (length_offset + TLS_HANDSHAKE_LENGTH > data_end) { return false; } @@ -210,7 +210,7 @@ static __always_inline bool is_valid_tls_handshake(struct __sk_buff *skb, __u64 // is_tls checks if the packet is a TLS packet by reading and validating the TLS record header // Reference: RFC 5246 Section 6.2.1 (Record Layer), https://tools.ietf.org/html/rfc5246#section-6.2.1 // Validates that content_type matches known TLS types (Handshake, Application Data, etc.). -static __always_inline bool is_tls(struct __sk_buff *skb, __u64 header_offset, __u32 data_end, tls_record_header_t *tls_hdr) { +static __always_inline bool is_tls(struct __sk_buff *skb, __u32 header_offset, __u32 data_end, tls_record_header_t *tls_hdr) { // Read and validate the TLS record header if (!read_tls_record_header(skb, header_offset, data_end, tls_hdr)) { return false; @@ -305,7 +305,7 @@ static __always_inline bool skip_random_and_session_id(struct __sk_buff *skb, __ // +---------------------+ // | selected_version(2) | // +---------------------+ -static __always_inline bool parse_supported_versions_extension(struct __sk_buff *skb, __u32 *offset, __u32 data_end, __u64 extensions_end, tls_info_t *tags, bool is_client_hello) { +static __always_inline bool parse_supported_versions_extension(struct __sk_buff *skb, __u32 *offset, __u32 data_end, __u32 extensions_end, tls_info_t *tags, bool is_client_hello) { if (is_client_hello) { // Read supported version list length (1 byte) if (*offset + SINGLE_BYTE_LENGTH > data_end || *offset + SINGLE_BYTE_LENGTH > extensions_end) { @@ -375,7 +375,7 @@ static __always_inline bool parse_supported_versions_extension(struct __sk_buff // | ext_type(2) | ext_length(2) | ext_data(ext_length) | // +---------+---------+--------------------------------+ // For multiple extensions, they are just concatenated one after another. -static __always_inline bool parse_tls_extensions(struct __sk_buff *skb, __u32 *offset, __u32 data_end, __u64 extensions_end, tls_info_t *tags, bool is_client_hello) { +static __always_inline bool parse_tls_extensions(struct __sk_buff *skb, __u32 *offset, __u32 data_end, __u32 extensions_end, tls_info_t *tags, bool is_client_hello) { __u16 extension_type; __u16 extension_length; @@ -506,7 +506,7 @@ static __always_inline bool parse_client_hello(struct __sk_buff *skb, __u32 offs return false; } - __u64 extensions_end = offset + extensions_length; + __u32 extensions_end = offset + extensions_length; return parse_tls_extensions(skb, &offset, data_end, extensions_end, tags, true); } @@ -582,12 +582,12 @@ static __always_inline bool parse_server_hello(struct __sk_buff *skb, __u32 offs extensions_length = bpf_ntohs(extensions_length); offset += EXTENSION_LENGTH_FIELD; - __u64 handshake_end = offset + handshake_length; + __u32 handshake_end = offset + handshake_length; if (offset + extensions_length > data_end || offset + extensions_length > handshake_end) { return false; } - __u64 extensions_end = offset + extensions_length; + __u32 extensions_end = offset + extensions_length; return parse_tls_extensions(skb, &offset, data_end, extensions_end, tags, false); } diff --git a/pkg/network/tracer/tracer_linux_test.go b/pkg/network/tracer/tracer_linux_test.go index c683216e06eacc..3825b9d1bef075 100644 --- a/pkg/network/tracer/tracer_linux_test.go +++ b/pkg/network/tracer/tracer_linux_test.go @@ -2654,8 +2654,8 @@ func (s *TracerSuite) TestTLSClassification() { return port, scenario }, validation: func(t *testing.T, tr *Tracer, port uint16, scenario uint16) { - require.Eventuallyf(t, func() bool { - return validateTLSTags(t, tr, port, scenario) + require.EventuallyWithT(t, func(ct *assert.CollectT) { + require.True(ct, validateTLSTags(ct, tr, port, scenario), "TLS tags not set") }, 3*time.Second, 100*time.Millisecond, "couldn't find TLS connection matching: dst port %v", port) }, }) @@ -2706,15 +2706,14 @@ func (s *TracerSuite) TestTLSClassification() { }, validation: func(t *testing.T, tr *Tracer, port uint16, _ uint16) { // Verify that no TLS tags are set for this connection - require.Eventually(t, func() bool { - payload := getConnections(t, tr) + require.EventuallyWithT(t, func(ct *assert.CollectT) { + payload := getConnections(ct, tr) for _, c := range payload.Conns { if c.DPort == port && c.ProtocolStack.Contains(protocols.TLS) { t.Log("Unexpected TLS protocol detected for invalid handshake") - return false + require.Fail(ct, "unexpected TLS tags") } } - return true }, 3*time.Second, 100*time.Millisecond) }, }) @@ -2738,7 +2737,7 @@ func (s *TracerSuite) TestTLSClassification() { } } -func validateTLSTags(t *testing.T, tr *Tracer, port uint16, scenario uint16) bool { +func validateTLSTags(t *assert.CollectT, tr *Tracer, port uint16, scenario uint16) bool { payload := getConnections(t, tr) for _, c := range payload.Conns { if c.DPort == port && c.ProtocolStack.Contains(protocols.TLS) && !c.TLSTags.IsEmpty() { @@ -2753,22 +2752,18 @@ func validateTLSTags(t *testing.T, tr *Tracer, port uint16, scenario uint16) boo } } if !cipherSuiteTagFound { - t.Log("Cipher suite ID tag missing") return false } // Check that the negotiated version tag is present negotiatedVersionTag := ddtls.VersionTags[scenario] if _, ok := tlsTags[negotiatedVersionTag]; !ok { - t.Logf("Negotiated version tag '%s' not found", negotiatedVersionTag) return false } // Check that the client offered version tag is present clientVersionTag := ddtls.ClientVersionTags[scenario] if _, ok := tlsTags[clientVersionTag]; !ok { - t.Log(tlsTags) - t.Logf("Client offered version tag '%s' not found", clientVersionTag) return false } @@ -2779,7 +2774,6 @@ func validateTLSTags(t *testing.T, tr *Tracer, port uint16, scenario uint16) boo } for _, tag := range expectedClientVersions { if _, ok := tlsTags[tag]; !ok { - t.Logf("Expected client offered version tag '%s' not found", tag) return false } }