From 1234864f76bef11dca9fd668f81952dc164f3a6a Mon Sep 17 00:00:00 2001 From: yunhanw-google Date: Mon, 28 Aug 2023 11:34:38 -0700 Subject: [PATCH] Fix buffer leak and jni bytes object leak for jni command/write (#28913) --- .../java/CHIPDeviceController-JNI.cpp | 192 +++++++++--------- 1 file changed, 93 insertions(+), 99 deletions(-) diff --git a/src/controller/java/CHIPDeviceController-JNI.cpp b/src/controller/java/CHIPDeviceController-JNI.cpp index d4c0dd93bfc7c9..900440c7ce865c 100644 --- a/src/controller/java/CHIPDeviceController-JNI.cpp +++ b/src/controller/java/CHIPDeviceController-JNI.cpp @@ -1865,6 +1865,36 @@ JNI_METHOD(void, read) } } +// Convert Json to Tlv, and remove the outer structure +CHIP_ERROR ConvertJsonToTlvWithoutStruct(const std::string & json, MutableByteSpan & data) +{ + Platform::ScopedMemoryBufferWithSize buf; + VerifyOrReturnError(buf.Calloc(data.size()), CHIP_ERROR_NO_MEMORY); + MutableByteSpan dataWithStruct(buf.Get(), buf.AllocatedSize()); + ReturnErrorOnFailure(JsonToTlv(json, dataWithStruct)); + TLV::TLVReader tlvReader; + TLV::TLVType outerContainer = TLV::kTLVType_Structure; + tlvReader.Init(dataWithStruct); + ReturnErrorOnFailure(tlvReader.Next(TLV::kTLVType_Structure, TLV::AnonymousTag())); + ReturnErrorOnFailure(tlvReader.EnterContainer(outerContainer)); + ReturnErrorOnFailure(tlvReader.Next()); + + TLV::TLVWriter tlvWrite; + tlvWrite.Init(data); + ReturnErrorOnFailure(tlvWrite.CopyElement(TLV::AnonymousTag(), tlvReader)); + ReturnErrorOnFailure(tlvWrite.Finalize()); + data.reduce_size(tlvWrite.GetLengthWritten()); + return CHIP_NO_ERROR; +} + +CHIP_ERROR PutPreencodedWriteAttribute(app::WriteClient & writeClient, app::ConcreteDataAttributePath & path, const ByteSpan & data) +{ + TLV::TLVReader reader; + reader.Init(data); + ReturnErrorOnFailure(reader.Next()); + return writeClient.PutPreencodedAttribute(path, reader); +} + JNI_METHOD(void, write) (JNIEnv * env, jobject self, jlong handle, jlong callbackHandle, jlong devicePtr, jobject attributeList, jint timedRequestTimeoutMs, jint imTimeoutMs) @@ -1875,8 +1905,6 @@ JNI_METHOD(void, write) auto callback = reinterpret_cast(callbackHandle); app::WriteClient * writeClient = nullptr; uint16_t convertedTimedRequestTimeoutMs = static_cast(timedRequestTimeoutMs); - bool hasValidTlv = false; - bool hasValidJson = false; ChipLogDetail(Controller, "IM write() called"); @@ -1909,9 +1937,6 @@ JNI_METHOD(void, write) jbyteArray tlvBytesObj = nullptr; bool hasDataVersion = false; Optional dataVersion = Optional(); - uint8_t * tlvBytes = nullptr; - size_t length = 0; - TLV::TLVReader reader; SuccessOrExit(err = JniReferences::GetInstance().GetListItem(attributeList, i, attributeItem)); SuccessOrExit(err = JniReferences::GetInstance().FindMethod( @@ -1955,14 +1980,12 @@ JNI_METHOD(void, write) tlvBytesObj = static_cast(env->CallObjectMethod(attributeItem, getTlvByteArrayMethod)); VerifyOrExit(!env->ExceptionCheck(), err = CHIP_JNI_ERROR_EXCEPTION_THROWN); + app::ConcreteDataAttributePath path(static_cast(endpointId), static_cast(clusterId), + static_cast(attributeId), dataVersion); if (tlvBytesObj != nullptr) { - jbyte * tlvBytesObjBytes = env->GetByteArrayElements(tlvBytesObj, nullptr); - VerifyOrExit(!env->ExceptionCheck(), err = CHIP_JNI_ERROR_EXCEPTION_THROWN); - length = static_cast(env->GetArrayLength(tlvBytesObj)); - VerifyOrExit(!env->ExceptionCheck(), err = CHIP_JNI_ERROR_EXCEPTION_THROWN); - tlvBytes = reinterpret_cast(tlvBytesObjBytes); - hasValidTlv = true; + JniByteArray tlvByteArray(env, tlvBytesObj); + SuccessOrExit(err = PutPreencodedWriteAttribute(*writeClient, path, tlvByteArray.byteSpan())); } else { @@ -1970,38 +1993,21 @@ JNI_METHOD(void, write) &getJsonStringMethod)); jstring jsonJniString = static_cast(env->CallObjectMethod(attributeItem, getJsonStringMethod)); VerifyOrExit(!env->ExceptionCheck(), err = CHIP_JNI_ERROR_EXCEPTION_THROWN); - if (jsonJniString != nullptr) - { - JniUtfString jsonUtfJniString(env, jsonJniString); - uint8_t bufWithStruct[chip::app::kMaxSecureSduLengthBytes] = { 0 }; - uint8_t buf[chip::app::kMaxSecureSduLengthBytes] = { 0 }; - TLV::TLVReader tlvReader; - TLV::TLVWriter tlvWrite; - TLV::TLVType outerContainer = TLV::kTLVType_Structure; - MutableByteSpan dataWithStruct{ bufWithStruct }; - MutableByteSpan data{ buf }; - SuccessOrExit(err = JsonToTlv(std::string(jsonUtfJniString.c_str(), jsonUtfJniString.size()), dataWithStruct)); - tlvReader.Init(dataWithStruct); - SuccessOrExit(err = tlvReader.Next(TLV::kTLVType_Structure, TLV::AnonymousTag())); - SuccessOrExit(err = tlvReader.EnterContainer(outerContainer)); - SuccessOrExit(err = tlvReader.Next()); - tlvWrite.Init(data); - SuccessOrExit(err = tlvWrite.CopyElement(TLV::AnonymousTag(), tlvReader)); - SuccessOrExit(err = tlvWrite.Finalize()); - tlvBytes = buf; - length = tlvWrite.GetLengthWritten(); - hasValidJson = true; - } + VerifyOrExit(jsonJniString != nullptr, err = CHIP_JNI_ERROR_EXCEPTION_THROWN); + JniUtfString jsonUtfJniString(env, jsonJniString); + std::string jsonString = std::string(jsonUtfJniString.c_str(), jsonUtfJniString.size()); + + // Context: Chunk write is supported in sdk, oversized list could be chunked in multiple message. When transforming + // JSON to TLV, we need know the actual size for tlv blob when handling JsonToTlv + // TODO: Implement memory auto-grow to get the actual size needed for tlv blob when transforming tlv to json. + // Workaround: Allocate memory using json string's size, which is large enough to hold the corresponding tlv blob + Platform::ScopedMemoryBufferWithSize tlvBytes; + size_t length = jsonUtfJniString.size(); + VerifyOrExit(tlvBytes.Calloc(length), err = CHIP_ERROR_NO_MEMORY); + MutableByteSpan data(tlvBytes.Get(), tlvBytes.AllocatedSize()); + SuccessOrExit(err = ConvertJsonToTlvWithoutStruct(jsonString, data)); + SuccessOrExit(err = PutPreencodedWriteAttribute(*writeClient, path, data)); } - VerifyOrExit(hasValidTlv || hasValidJson, err = CHIP_ERROR_INVALID_ARGUMENT); - - reader.Init(tlvBytes, length); - reader.Next(); - SuccessOrExit( - err = writeClient->PutPreencodedAttribute( - chip::app::ConcreteDataAttributePath(static_cast(endpointId), static_cast(clusterId), - static_cast(attributeId), dataVersion), - reader)); } err = writeClient->SendWriteRequest(device->GetSecureSession().Value(), @@ -2030,34 +2036,40 @@ JNI_METHOD(void, write) } } +CHIP_ERROR PutPreencodedInvokeRequest(app::CommandSender & commandSender, app::CommandPathParams & path, const ByteSpan & data) +{ + // PrepareCommand does nott create the struct container with kFields and copycontainer below sets the + // kFields container already + ReturnErrorOnFailure(commandSender.PrepareCommand(path, false /* aStartDataStruct */)); + TLV::TLVWriter * writer = commandSender.GetCommandDataIBTLVWriter(); + VerifyOrReturnError(writer != nullptr, CHIP_ERROR_INCORRECT_STATE); + TLV::TLVReader reader; + reader.Init(data); + ReturnErrorOnFailure(reader.Next()); + return writer->CopyContainer(TLV::ContextTag(app::CommandDataIB::Tag::kFields), reader); +} + JNI_METHOD(void, invoke) (JNIEnv * env, jobject self, jlong handle, jlong callbackHandle, jlong devicePtr, jobject invokeElement, jint timedRequestTimeoutMs, jint imTimeoutMs) { chip::DeviceLayer::StackLock lock; - CHIP_ERROR err = CHIP_NO_ERROR; - auto callback = reinterpret_cast(callbackHandle); - app::CommandSender * commandSender = nullptr; - uint32_t endpointId = 0; - uint32_t clusterId = 0; - uint32_t commandId = 0; - jmethodID getEndpointIdMethod = nullptr; - jmethodID getClusterIdMethod = nullptr; - jmethodID getCommandIdMethod = nullptr; - jmethodID getTlvByteArrayMethod = nullptr; - jmethodID getJsonStringMethod = nullptr; - jobject endpointIdObj = nullptr; - jobject clusterIdObj = nullptr; - jobject commandIdObj = nullptr; - jbyteArray tlvBytesObj = nullptr; - TLV::TLVReader reader; - TLV::TLVWriter * writer = nullptr; - uint8_t * tlvBytes = nullptr; - size_t length = 0; - bool hasValidTlv = false; - bool hasValidJson = false; + CHIP_ERROR err = CHIP_NO_ERROR; + auto callback = reinterpret_cast(callbackHandle); + app::CommandSender * commandSender = nullptr; + uint32_t endpointId = 0; + uint32_t clusterId = 0; + uint32_t commandId = 0; + jmethodID getEndpointIdMethod = nullptr; + jmethodID getClusterIdMethod = nullptr; + jmethodID getCommandIdMethod = nullptr; + jmethodID getTlvByteArrayMethod = nullptr; + jmethodID getJsonStringMethod = nullptr; + jobject endpointIdObj = nullptr; + jobject clusterIdObj = nullptr; + jobject commandIdObj = nullptr; + jbyteArray tlvBytesObj = nullptr; uint16_t convertedTimedRequestTimeoutMs = static_cast(timedRequestTimeoutMs); - ChipLogDetail(Controller, "IM invoke() called"); DeviceProxy * device = reinterpret_cast(devicePtr); @@ -2093,49 +2105,32 @@ JNI_METHOD(void, invoke) tlvBytesObj = static_cast(env->CallObjectMethod(invokeElement, getTlvByteArrayMethod)); VerifyOrExit(!env->ExceptionCheck(), err = CHIP_JNI_ERROR_EXCEPTION_THROWN); - if (tlvBytesObj != nullptr) - { - jbyte * tlvBytesObjBytes = env->GetByteArrayElements(tlvBytesObj, nullptr); - VerifyOrExit(!env->ExceptionCheck(), err = CHIP_JNI_ERROR_EXCEPTION_THROWN); - length = static_cast(env->GetArrayLength(tlvBytesObj)); - VerifyOrExit(!env->ExceptionCheck(), err = CHIP_JNI_ERROR_EXCEPTION_THROWN); - tlvBytes = reinterpret_cast(tlvBytesObjBytes); - hasValidTlv = true; - } - else { - SuccessOrExit(err = JniReferences::GetInstance().FindMethod(env, invokeElement, "getJsonString", "()Ljava/lang/String;", - &getJsonStringMethod)); - jstring jsonJniString = static_cast(env->CallObjectMethod(invokeElement, getJsonStringMethod)); - VerifyOrExit(!env->ExceptionCheck(), err = CHIP_JNI_ERROR_EXCEPTION_THROWN); - if (jsonJniString != nullptr) + app::CommandPathParams path(static_cast(endpointId), /* group id */ 0, static_cast(clusterId), + static_cast(commandId), app::CommandPathFlags::kEndpointIdValid); + if (tlvBytesObj != nullptr) { + JniByteArray tlvBytesObjBytes(env, tlvBytesObj); + SuccessOrExit(err = PutPreencodedInvokeRequest(*commandSender, path, tlvBytesObjBytes.byteSpan())); + } + else + { + SuccessOrExit(err = JniReferences::GetInstance().FindMethod(env, invokeElement, "getJsonString", "()Ljava/lang/String;", + &getJsonStringMethod)); + jstring jsonJniString = static_cast(env->CallObjectMethod(invokeElement, getJsonStringMethod)); + VerifyOrExit(!env->ExceptionCheck(), err = CHIP_JNI_ERROR_EXCEPTION_THROWN); + VerifyOrExit(jsonJniString != nullptr, err = CHIP_ERROR_INVALID_ARGUMENT); JniUtfString jsonUtfJniString(env, jsonJniString); - uint8_t buf[chip::app::kMaxSecureSduLengthBytes] = { 0 }; - MutableByteSpan tlvEncodingLocal{ buf }; + // The invoke does not support chunk, kMaxSecureSduLengthBytes should be enough for command json blob + uint8_t tlvBytes[chip::app::kMaxSecureSduLengthBytes] = { 0 }; + MutableByteSpan tlvEncodingLocal{ tlvBytes }; SuccessOrExit(err = JsonToTlv(std::string(jsonUtfJniString.c_str(), jsonUtfJniString.size()), tlvEncodingLocal)); - tlvBytes = tlvEncodingLocal.data(); - length = tlvEncodingLocal.size(); - hasValidJson = true; + SuccessOrExit(err = PutPreencodedInvokeRequest(*commandSender, path, tlvEncodingLocal)); } } - VerifyOrExit(hasValidTlv || hasValidJson, err = CHIP_ERROR_INVALID_ARGUMENT); - - SuccessOrExit(err = commandSender->PrepareCommand(app::CommandPathParams(static_cast(endpointId), /* group id */ 0, - static_cast(clusterId), - static_cast(commandId), - app::CommandPathFlags::kEndpointIdValid), - false)); - - writer = commandSender->GetCommandDataIBTLVWriter(); - VerifyOrExit(writer != nullptr, err = CHIP_ERROR_INCORRECT_STATE); - reader.Init(tlvBytes, static_cast(length)); - reader.Next(); - SuccessOrExit(err = writer->CopyContainer(TLV::ContextTag(app::CommandDataIB::Tag::kFields), reader)); SuccessOrExit(err = commandSender->FinishCommand(convertedTimedRequestTimeoutMs != 0 ? Optional(convertedTimedRequestTimeoutMs) : Optional::Missing())); - SuccessOrExit(err = commandSender->SendCommandRequest(device->GetSecureSession().Value(), imTimeoutMs != 0 ? MakeOptional(System::Clock::Milliseconds32(imTimeoutMs)) @@ -2144,7 +2139,6 @@ JNI_METHOD(void, invoke) callback->mCommandSender = commandSender; exit: - if (err != CHIP_NO_ERROR) { ChipLogError(Controller, "JNI IM Invoke Error: %s", err.AsString());