diff --git a/net/rxrpc/conn_event.c b/net/rxrpc/conn_event.c index 4ca11be6be3cad..b1dfae10743102 100644 --- a/net/rxrpc/conn_event.c +++ b/net/rxrpc/conn_event.c @@ -460,6 +460,7 @@ void rxrpc_process_connection(struct work_struct *work) case -EKEYEXPIRED: case -EKEYREJECTED: goto protocol_error; + case -ENOMEM: case -EAGAIN: goto requeue_and_leave; case -ECONNABORTED: diff --git a/net/rxrpc/rxkad.c b/net/rxrpc/rxkad.c index c38b3a1de56c13..77cb23c7bd0a84 100644 --- a/net/rxrpc/rxkad.c +++ b/net/rxrpc/rxkad.c @@ -773,8 +773,7 @@ static int rxkad_respond_to_challenge(struct rxrpc_connection *conn, { const struct rxrpc_key_token *token; struct rxkad_challenge challenge; - struct rxkad_response resp - __attribute__((aligned(8))); /* must be aligned for crypto */ + struct rxkad_response *resp; struct rxrpc_skb_priv *sp = rxrpc_skb(skb); const char *eproto; u32 version, nonce, min_level, abort_code; @@ -818,26 +817,29 @@ static int rxkad_respond_to_challenge(struct rxrpc_connection *conn, token = conn->params.key->payload.data[0]; /* build the response packet */ - memset(&resp, 0, sizeof(resp)); - - resp.version = htonl(RXKAD_VERSION); - resp.encrypted.epoch = htonl(conn->proto.epoch); - resp.encrypted.cid = htonl(conn->proto.cid); - resp.encrypted.securityIndex = htonl(conn->security_ix); - resp.encrypted.inc_nonce = htonl(nonce + 1); - resp.encrypted.level = htonl(conn->params.security_level); - resp.kvno = htonl(token->kad->kvno); - resp.ticket_len = htonl(token->kad->ticket_len); - - resp.encrypted.call_id[0] = htonl(conn->channels[0].call_counter); - resp.encrypted.call_id[1] = htonl(conn->channels[1].call_counter); - resp.encrypted.call_id[2] = htonl(conn->channels[2].call_counter); - resp.encrypted.call_id[3] = htonl(conn->channels[3].call_counter); + resp = kzalloc(sizeof(struct rxkad_response), GFP_NOFS); + if (!resp) + return -ENOMEM; + + resp->version = htonl(RXKAD_VERSION); + resp->encrypted.epoch = htonl(conn->proto.epoch); + resp->encrypted.cid = htonl(conn->proto.cid); + resp->encrypted.securityIndex = htonl(conn->security_ix); + resp->encrypted.inc_nonce = htonl(nonce + 1); + resp->encrypted.level = htonl(conn->params.security_level); + resp->kvno = htonl(token->kad->kvno); + resp->ticket_len = htonl(token->kad->ticket_len); + resp->encrypted.call_id[0] = htonl(conn->channels[0].call_counter); + resp->encrypted.call_id[1] = htonl(conn->channels[1].call_counter); + resp->encrypted.call_id[2] = htonl(conn->channels[2].call_counter); + resp->encrypted.call_id[3] = htonl(conn->channels[3].call_counter); /* calculate the response checksum and then do the encryption */ - rxkad_calc_response_checksum(&resp); - rxkad_encrypt_response(conn, &resp, token->kad); - return rxkad_send_response(conn, &sp->hdr, &resp, token->kad); + rxkad_calc_response_checksum(resp); + rxkad_encrypt_response(conn, resp, token->kad); + ret = rxkad_send_response(conn, &sp->hdr, resp, token->kad); + kfree(resp); + return ret; protocol_error: trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto); @@ -1048,8 +1050,7 @@ static int rxkad_verify_response(struct rxrpc_connection *conn, struct sk_buff *skb, u32 *_abort_code) { - struct rxkad_response response - __attribute__((aligned(8))); /* must be aligned for crypto */ + struct rxkad_response *response; struct rxrpc_skb_priv *sp = rxrpc_skb(skb); struct rxrpc_crypt session_key; const char *eproto; @@ -1061,17 +1062,22 @@ static int rxkad_verify_response(struct rxrpc_connection *conn, _enter("{%d,%x}", conn->debug_id, key_serial(conn->server_key)); + ret = -ENOMEM; + response = kzalloc(sizeof(struct rxkad_response), GFP_NOFS); + if (!response) + goto temporary_error; + eproto = tracepoint_string("rxkad_rsp_short"); abort_code = RXKADPACKETSHORT; if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header), - &response, sizeof(response)) < 0) + response, sizeof(*response)) < 0) goto protocol_error; - if (!pskb_pull(skb, sizeof(response))) + if (!pskb_pull(skb, sizeof(*response))) BUG(); - version = ntohl(response.version); - ticket_len = ntohl(response.ticket_len); - kvno = ntohl(response.kvno); + version = ntohl(response->version); + ticket_len = ntohl(response->ticket_len); + kvno = ntohl(response->kvno); _proto("Rx RESPONSE %%%u { v=%u kv=%u tl=%u }", sp->hdr.serial, version, kvno, ticket_len); @@ -1105,31 +1111,31 @@ static int rxkad_verify_response(struct rxrpc_connection *conn, ret = rxkad_decrypt_ticket(conn, skb, ticket, ticket_len, &session_key, &expiry, _abort_code); if (ret < 0) - goto temporary_error_free; + goto temporary_error_free_resp; /* use the session key from inside the ticket to decrypt the * response */ - rxkad_decrypt_response(conn, &response, &session_key); + rxkad_decrypt_response(conn, response, &session_key); eproto = tracepoint_string("rxkad_rsp_param"); abort_code = RXKADSEALEDINCON; - if (ntohl(response.encrypted.epoch) != conn->proto.epoch) + if (ntohl(response->encrypted.epoch) != conn->proto.epoch) goto protocol_error_free; - if (ntohl(response.encrypted.cid) != conn->proto.cid) + if (ntohl(response->encrypted.cid) != conn->proto.cid) goto protocol_error_free; - if (ntohl(response.encrypted.securityIndex) != conn->security_ix) + if (ntohl(response->encrypted.securityIndex) != conn->security_ix) goto protocol_error_free; - csum = response.encrypted.checksum; - response.encrypted.checksum = 0; - rxkad_calc_response_checksum(&response); + csum = response->encrypted.checksum; + response->encrypted.checksum = 0; + rxkad_calc_response_checksum(response); eproto = tracepoint_string("rxkad_rsp_csum"); - if (response.encrypted.checksum != csum) + if (response->encrypted.checksum != csum) goto protocol_error_free; spin_lock(&conn->channel_lock); for (i = 0; i < RXRPC_MAXCALLS; i++) { struct rxrpc_call *call; - u32 call_id = ntohl(response.encrypted.call_id[i]); + u32 call_id = ntohl(response->encrypted.call_id[i]); eproto = tracepoint_string("rxkad_rsp_callid"); if (call_id > INT_MAX) @@ -1153,12 +1159,12 @@ static int rxkad_verify_response(struct rxrpc_connection *conn, eproto = tracepoint_string("rxkad_rsp_seq"); abort_code = RXKADOUTOFSEQUENCE; - if (ntohl(response.encrypted.inc_nonce) != conn->security_nonce + 1) + if (ntohl(response->encrypted.inc_nonce) != conn->security_nonce + 1) goto protocol_error_free; eproto = tracepoint_string("rxkad_rsp_level"); abort_code = RXKADLEVELFAIL; - level = ntohl(response.encrypted.level); + level = ntohl(response->encrypted.level); if (level > RXRPC_SECURITY_ENCRYPT) goto protocol_error_free; conn->params.security_level = level; @@ -1168,9 +1174,10 @@ static int rxkad_verify_response(struct rxrpc_connection *conn, * as for a client connection */ ret = rxrpc_get_server_data_key(conn, &session_key, expiry, kvno); if (ret < 0) - goto temporary_error_free; + goto temporary_error_free_ticket; kfree(ticket); + kfree(response); _leave(" = 0"); return 0; @@ -1179,12 +1186,15 @@ static int rxkad_verify_response(struct rxrpc_connection *conn, protocol_error_free: kfree(ticket); protocol_error: + kfree(response); trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto); *_abort_code = abort_code; return -EPROTO; -temporary_error_free: +temporary_error_free_ticket: kfree(ticket); +temporary_error_free_resp: + kfree(response); temporary_error: /* Ignore the response packet if we got a temporary error such as * ENOMEM. We just want to send the challenge again. Note that we