+ getEncryptedTensorMapMap();
+ /**
+ * map<string, .EncryptedTensor> encryptedTensorMap = 3;
+ */
+
+ com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor getEncryptedTensorMapOrDefault(
+ java.lang.String key,
+ com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor defaultValue);
+ /**
+ * map<string, .EncryptedTensor> encryptedTensorMap = 3;
+ */
+
+ com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor getEncryptedTensorMapOrThrow(
+ java.lang.String key);
}
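For orientation (not part of the generated diff): a minimal consumer-side sketch of the accessors declared above. The class and method names `EncryptedTensorMapReader`/`printEncryptedShapes` are hypothetical; the accessors themselves (`getEncryptedTensorMapMap()`, `getShapeList()`, `getDtype()`) are the ones this patch generates.

  import com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto;

  final class EncryptedTensorMapReader {
    // Iterate the new map<string, EncryptedTensor> field of a received TensorMap.
    static void printEncryptedShapes(FlBaseProto.TensorMap msg) {
      for (java.util.Map.Entry<java.lang.String, FlBaseProto.EncryptedTensor> e
          : msg.getEncryptedTensorMapMap().entrySet()) {
        // Each value carries opaque ciphertext plus shape/dtype metadata.
        System.out.println(e.getKey() + " -> shape=" + e.getValue().getShapeList()
            + ", dtype=" + e.getValue().getDtype());
      }
    }
  }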
/**
*
@@ -1163,11 +1078,10 @@ com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.FloatTensor getTensorMap
*
* Protobuf type {@code TensorMap}
*/
- public static final class TensorMap extends
+ public static final class TensorMap extends
com.google.protobuf.GeneratedMessageV3 implements
// @@protoc_insertion_point(message_implements:TensorMap)
TensorMapOrBuilder {
- private static final long serialVersionUID = 0L;
// Use TensorMap.newBuilder() to construct.
private TensorMap(com.google.protobuf.GeneratedMessageV3.Builder<?> builder) {
super(builder);
@@ -1175,29 +1089,17 @@ private TensorMap(com.google.protobuf.GeneratedMessageV3.Builder<?> builder) {
private TensorMap() {
}
- @java.lang.Override
- @SuppressWarnings({"unused"})
- protected java.lang.Object newInstance(
- UnusedPrivateParameter unused) {
- return new TensorMap();
- }
-
@java.lang.Override
public final com.google.protobuf.UnknownFieldSet
getUnknownFields() {
- return this.unknownFields;
+ return com.google.protobuf.UnknownFieldSet.getDefaultInstance();
}
private TensorMap(
com.google.protobuf.CodedInputStream input,
com.google.protobuf.ExtensionRegistryLite extensionRegistry)
throws com.google.protobuf.InvalidProtocolBufferException {
this();
- if (extensionRegistry == null) {
- throw new java.lang.NullPointerException();
- }
int mutable_bitField0_ = 0;
- com.google.protobuf.UnknownFieldSet.Builder unknownFields =
- com.google.protobuf.UnknownFieldSet.newBuilder();
try {
boolean done = false;
while (!done) {
@@ -1206,6 +1108,12 @@ private TensorMap(
case 0:
done = true;
break;
+ default: {
+ if (!input.skipField(tag)) {
+ done = true;
+ }
+ break;
+ }
case 10: {
com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData.Builder subBuilder = null;
if (metaData_ != null) {
@@ -1220,23 +1128,27 @@ private TensorMap(
break;
}
case 18: {
- if (!((mutable_bitField0_ & 0x00000001) != 0)) {
+ if (!((mutable_bitField0_ & 0x00000002) == 0x00000002)) {
tensorMap_ = com.google.protobuf.MapField.newMapField(
TensorMapDefaultEntryHolder.defaultEntry);
- mutable_bitField0_ |= 0x00000001;
+ mutable_bitField0_ |= 0x00000002;
}
com.google.protobuf.MapEntry<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.FloatTensor>
- tensorMap__ = input.readMessage(
+ tensorMap = input.readMessage(
TensorMapDefaultEntryHolder.defaultEntry.getParserForType(), extensionRegistry);
- tensorMap_.getMutableMap().put(
- tensorMap__.getKey(), tensorMap__.getValue());
+ tensorMap_.getMutableMap().put(tensorMap.getKey(), tensorMap.getValue());
break;
}
- default: {
- if (!parseUnknownField(
- input, unknownFields, extensionRegistry, tag)) {
- done = true;
+ case 26: {
+ if (!((mutable_bitField0_ & 0x00000004) == 0x00000004)) {
+ encryptedTensorMap_ = com.google.protobuf.MapField.newMapField(
+ EncryptedTensorMapDefaultEntryHolder.defaultEntry);
+ mutable_bitField0_ |= 0x00000004;
}
+ com.google.protobuf.MapEntry<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor>
+ encryptedTensorMap = input.readMessage(
+ EncryptedTensorMapDefaultEntryHolder.defaultEntry.getParserForType(), extensionRegistry);
+ encryptedTensorMap_.getMutableMap().put(encryptedTensorMap.getKey(), encryptedTensorMap.getValue());
break;
}
}
@@ -1247,7 +1159,6 @@ private TensorMap(
throw new com.google.protobuf.InvalidProtocolBufferException(
e).setUnfinishedMessage(this);
} finally {
- this.unknownFields = unknownFields.build();
makeExtensionsImmutable();
}
}
@@ -1257,18 +1168,18 @@ private TensorMap(
}
@SuppressWarnings({"rawtypes"})
- @java.lang.Override
protected com.google.protobuf.MapField internalGetMapField(
int number) {
switch (number) {
case 2:
return internalGetTensorMap();
+ case 3:
+ return internalGetEncryptedTensorMap();
default:
throw new RuntimeException(
"Invalid map field number: " + number);
}
}
- @java.lang.Override
protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable
internalGetFieldAccessorTable() {
return com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.internal_static_TensorMap_fieldAccessorTable
@@ -1276,28 +1187,24 @@ protected com.google.protobuf.MapField internalGetMapField(
com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.TensorMap.class, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.TensorMap.Builder.class);
}
+ private int bitField0_;
public static final int METADATA_FIELD_NUMBER = 1;
private com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData metaData_;
/**
- * .MetaData metaData = 1;
- * @return Whether the metaData field is set.
+ * optional .MetaData metaData = 1;
*/
- @java.lang.Override
public boolean hasMetaData() {
return metaData_ != null;
}
/**
- * .MetaData metaData = 1;
- * @return The metaData.
+ * optional .MetaData metaData = 1;
*/
- @java.lang.Override
public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData getMetaData() {
return metaData_ == null ? com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData.getDefaultInstance() : metaData_;
}
/**
- * .MetaData metaData = 1;
+ * optional .MetaData metaData = 1;
*/
- @java.lang.Override
public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaDataOrBuilder getMetaDataOrBuilder() {
return getMetaData();
}
@@ -1332,7 +1239,6 @@ public int getTensorMapCount() {
* map<string, .FloatTensor> tensorMap = 2;
*/
- @java.lang.Override
public boolean containsTensorMap(
java.lang.String key) {
if (key == null) { throw new java.lang.NullPointerException(); }
@@ -1341,7 +1247,6 @@ public boolean containsTensorMap(
/**
* Use {@link #getTensorMapMap()} instead.
*/
- @java.lang.Override
@java.lang.Deprecated
public java.util.Map<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.FloatTensor> getTensorMap() {
return getTensorMapMap();
@@ -1349,7 +1254,6 @@ public java.util.Map getTensorMap() {
/**
* map<string, .FloatTensor> tensorMap = 2;
*/
- @java.lang.Override
public java.util.Map<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.FloatTensor> getTensorMapMap() {
return internalGetTensorMap().getMap();
@@ -1357,7 +1261,6 @@ public java.util.Map getTensorMapMap() {
/**
* map<string, .FloatTensor> tensorMap = 2;
*/
- @java.lang.Override
public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.FloatTensor getTensorMapOrDefault(
java.lang.String key,
@@ -1370,7 +1273,6 @@ public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.FloatTensor getTe
/**
* map<string, .FloatTensor> tensorMap = 2;
*/
- @java.lang.Override
public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.FloatTensor getTensorMapOrThrow(
java.lang.String key) {
@@ -1383,8 +1285,83 @@ public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.FloatTensor getTe
return map.get(key);
}
+ public static final int ENCRYPTEDTENSORMAP_FIELD_NUMBER = 3;
+ private static final class EncryptedTensorMapDefaultEntryHolder {
+ static final com.google.protobuf.MapEntry<
+ java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor> defaultEntry =
+ com.google.protobuf.MapEntry
+ .<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor>newDefaultInstance(
+ com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.internal_static_TensorMap_EncryptedTensorMapEntry_descriptor,
+ com.google.protobuf.WireFormat.FieldType.STRING,
+ "",
+ com.google.protobuf.WireFormat.FieldType.MESSAGE,
+ com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor.getDefaultInstance());
+ }
+ private com.google.protobuf.MapField<
+ java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor> encryptedTensorMap_;
+ private com.google.protobuf.MapField<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor>
+ internalGetEncryptedTensorMap() {
+ if (encryptedTensorMap_ == null) {
+ return com.google.protobuf.MapField.emptyMapField(
+ EncryptedTensorMapDefaultEntryHolder.defaultEntry);
+ }
+ return encryptedTensorMap_;
+ }
+
+ public int getEncryptedTensorMapCount() {
+ return internalGetEncryptedTensorMap().getMap().size();
+ }
+ /**
+ * map<string, .EncryptedTensor> encryptedTensorMap = 3;
+ */
+
+ public boolean containsEncryptedTensorMap(
+ java.lang.String key) {
+ if (key == null) { throw new java.lang.NullPointerException(); }
+ return internalGetEncryptedTensorMap().getMap().containsKey(key);
+ }
+ /**
+ * Use {@link #getEncryptedTensorMapMap()} instead.
+ */
+ @java.lang.Deprecated
+ public java.util.Map<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor> getEncryptedTensorMap() {
+ return getEncryptedTensorMapMap();
+ }
+ /**
+ * map<string, .EncryptedTensor> encryptedTensorMap = 3;
+ */
+
+ public java.util.Map<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor> getEncryptedTensorMapMap() {
+ return internalGetEncryptedTensorMap().getMap();
+ }
+ /**
+ * map<string, .EncryptedTensor> encryptedTensorMap = 3;
+ */
+
+ public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor getEncryptedTensorMapOrDefault(
+ java.lang.String key,
+ com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor defaultValue) {
+ if (key == null) { throw new java.lang.NullPointerException(); }
+ java.util.Map<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor> map =
+ internalGetEncryptedTensorMap().getMap();
+ return map.containsKey(key) ? map.get(key) : defaultValue;
+ }
+ /**
+ * map<string, .EncryptedTensor> encryptedTensorMap = 3;
+ */
+
+ public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor getEncryptedTensorMapOrThrow(
+ java.lang.String key) {
+ if (key == null) { throw new java.lang.NullPointerException(); }
+ java.util.Map<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor> map =
+ internalGetEncryptedTensorMap().getMap();
+ if (!map.containsKey(key)) {
+ throw new java.lang.IllegalArgumentException();
+ }
+ return map.get(key);
+ }
+
private byte memoizedIsInitialized = -1;
- @java.lang.Override
public final boolean isInitialized() {
byte isInitialized = memoizedIsInitialized;
if (isInitialized == 1) return true;
@@ -1394,22 +1371,31 @@ public final boolean isInitialized() {
return true;
}
- @java.lang.Override
public void writeTo(com.google.protobuf.CodedOutputStream output)
throws java.io.IOException {
if (metaData_ != null) {
output.writeMessage(1, getMetaData());
}
- com.google.protobuf.GeneratedMessageV3
- .serializeStringMapTo(
- output,
- internalGetTensorMap(),
- TensorMapDefaultEntryHolder.defaultEntry,
- 2);
- unknownFields.writeTo(output);
+ for (java.util.Map.Entry<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.FloatTensor> entry
+ : internalGetTensorMap().getMap().entrySet()) {
+ com.google.protobuf.MapEntry<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.FloatTensor>
+ tensorMap = TensorMapDefaultEntryHolder.defaultEntry.newBuilderForType()
+ .setKey(entry.getKey())
+ .setValue(entry.getValue())
+ .build();
+ output.writeMessage(2, tensorMap);
+ }
+ for (java.util.Map.Entry<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor> entry
+ : internalGetEncryptedTensorMap().getMap().entrySet()) {
+ com.google.protobuf.MapEntry<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor>
+ encryptedTensorMap = EncryptedTensorMapDefaultEntryHolder.defaultEntry.newBuilderForType()
+ .setKey(entry.getKey())
+ .setValue(entry.getValue())
+ .build();
+ output.writeMessage(3, encryptedTensorMap);
+ }
}
- @java.lang.Override
public int getSerializedSize() {
int size = memoizedSize;
if (size != -1) return size;
@@ -1422,18 +1408,28 @@ public int getSerializedSize() {
for (java.util.Map.Entry<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.FloatTensor> entry
: internalGetTensorMap().getMap().entrySet()) {
com.google.protobuf.MapEntry<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.FloatTensor>
- tensorMap__ = TensorMapDefaultEntryHolder.defaultEntry.newBuilderForType()
+ tensorMap = TensorMapDefaultEntryHolder.defaultEntry.newBuilderForType()
+ .setKey(entry.getKey())
+ .setValue(entry.getValue())
+ .build();
+ size += com.google.protobuf.CodedOutputStream
+ .computeMessageSize(2, tensorMap);
+ }
+ for (java.util.Map.Entry<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor> entry
+ : internalGetEncryptedTensorMap().getMap().entrySet()) {
+ com.google.protobuf.MapEntry<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor>
+ encryptedTensorMap = EncryptedTensorMapDefaultEntryHolder.defaultEntry.newBuilderForType()
.setKey(entry.getKey())
.setValue(entry.getValue())
.build();
size += com.google.protobuf.CodedOutputStream
- .computeMessageSize(2, tensorMap__);
+ .computeMessageSize(3, encryptedTensorMap);
}
- size += unknownFields.getSerializedSize();
memoizedSize = size;
return size;
}
+ private static final long serialVersionUID = 0L;
@java.lang.Override
public boolean equals(final java.lang.Object obj) {
if (obj == this) {
@@ -1444,15 +1440,17 @@ public boolean equals(final java.lang.Object obj) {
}
com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.TensorMap other = (com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.TensorMap) obj;
- if (hasMetaData() != other.hasMetaData()) return false;
+ boolean result = true;
+ result = result && (hasMetaData() == other.hasMetaData());
if (hasMetaData()) {
- if (!getMetaData()
- .equals(other.getMetaData())) return false;
+ result = result && getMetaData()
+ .equals(other.getMetaData());
}
- if (!internalGetTensorMap().equals(
- other.internalGetTensorMap())) return false;
- if (!unknownFields.equals(other.unknownFields)) return false;
- return true;
+ result = result && internalGetTensorMap().equals(
+ other.internalGetTensorMap());
+ result = result && internalGetEncryptedTensorMap().equals(
+ other.internalGetEncryptedTensorMap());
+ return result;
}
@java.lang.Override
@@ -1461,7 +1459,7 @@ public int hashCode() {
return memoizedHashCode;
}
int hash = 41;
- hash = (19 * hash) + getDescriptor().hashCode();
+ hash = (19 * hash) + getDescriptorForType().hashCode();
if (hasMetaData()) {
hash = (37 * hash) + METADATA_FIELD_NUMBER;
hash = (53 * hash) + getMetaData().hashCode();
@@ -1470,22 +1468,15 @@ public int hashCode() {
hash = (37 * hash) + TENSORMAP_FIELD_NUMBER;
hash = (53 * hash) + internalGetTensorMap().hashCode();
}
+ if (!internalGetEncryptedTensorMap().getMap().isEmpty()) {
+ hash = (37 * hash) + ENCRYPTEDTENSORMAP_FIELD_NUMBER;
+ hash = (53 * hash) + internalGetEncryptedTensorMap().hashCode();
+ }
hash = (29 * hash) + unknownFields.hashCode();
memoizedHashCode = hash;
return hash;
}
- public static com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.TensorMap parseFrom(
- java.nio.ByteBuffer data)
- throws com.google.protobuf.InvalidProtocolBufferException {
- return PARSER.parseFrom(data);
- }
- public static com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.TensorMap parseFrom(
- java.nio.ByteBuffer data,
- com.google.protobuf.ExtensionRegistryLite extensionRegistry)
- throws com.google.protobuf.InvalidProtocolBufferException {
- return PARSER.parseFrom(data, extensionRegistry);
- }
public static com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.TensorMap parseFrom(
com.google.protobuf.ByteString data)
throws com.google.protobuf.InvalidProtocolBufferException {
@@ -1545,7 +1536,6 @@ public static com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.TensorMap
.parseWithIOException(PARSER, input, extensionRegistry);
}
- @java.lang.Override
public Builder newBuilderForType() { return newBuilder(); }
public static Builder newBuilder() {
return DEFAULT_INSTANCE.toBuilder();
@@ -1553,7 +1543,6 @@ public static Builder newBuilder() {
public static Builder newBuilder(com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.TensorMap prototype) {
return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype);
}
- @java.lang.Override
public Builder toBuilder() {
return this == DEFAULT_INSTANCE
? new Builder() : new Builder().mergeFrom(this);
@@ -1586,6 +1575,8 @@ protected com.google.protobuf.MapField internalGetMapField(
switch (number) {
case 2:
return internalGetTensorMap();
+ case 3:
+ return internalGetEncryptedTensorMap();
default:
throw new RuntimeException(
"Invalid map field number: " + number);
@@ -1597,12 +1588,13 @@ protected com.google.protobuf.MapField internalGetMutableMapField(
switch (number) {
case 2:
return internalGetMutableTensorMap();
+ case 3:
+ return internalGetMutableEncryptedTensorMap();
default:
throw new RuntimeException(
"Invalid map field number: " + number);
}
}
- @java.lang.Override
protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable
internalGetFieldAccessorTable() {
return com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.internal_static_TensorMap_fieldAccessorTable
@@ -1625,7 +1617,6 @@ private void maybeForceBuilderInitialization() {
.alwaysUseFieldBuilders) {
}
}
- @java.lang.Override
public Builder clear() {
super.clear();
if (metaDataBuilder_ == null) {
@@ -1635,21 +1626,19 @@ public Builder clear() {
metaDataBuilder_ = null;
}
internalGetMutableTensorMap().clear();
+ internalGetMutableEncryptedTensorMap().clear();
return this;
}
- @java.lang.Override
public com.google.protobuf.Descriptors.Descriptor
getDescriptorForType() {
return com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.internal_static_TensorMap_descriptor;
}
- @java.lang.Override
public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.TensorMap getDefaultInstanceForType() {
return com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.TensorMap.getDefaultInstance();
}
- @java.lang.Override
public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.TensorMap build() {
com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.TensorMap result = buildPartial();
if (!result.isInitialized()) {
@@ -1658,10 +1647,10 @@ public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.TensorMap build()
return result;
}
- @java.lang.Override
public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.TensorMap buildPartial() {
com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.TensorMap result = new com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.TensorMap(this);
int from_bitField0_ = bitField0_;
+ int to_bitField0_ = 0;
if (metaDataBuilder_ == null) {
result.metaData_ = metaData_;
} else {
@@ -1669,43 +1658,39 @@ public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.TensorMap buildPa
}
result.tensorMap_ = internalGetTensorMap();
result.tensorMap_.makeImmutable();
+ result.encryptedTensorMap_ = internalGetEncryptedTensorMap();
+ result.encryptedTensorMap_.makeImmutable();
+ result.bitField0_ = to_bitField0_;
onBuilt();
return result;
}
- @java.lang.Override
public Builder clone() {
- return super.clone();
+ return (Builder) super.clone();
}
- @java.lang.Override
public Builder setField(
com.google.protobuf.Descriptors.FieldDescriptor field,
- java.lang.Object value) {
- return super.setField(field, value);
+ Object value) {
+ return (Builder) super.setField(field, value);
}
- @java.lang.Override
public Builder clearField(
com.google.protobuf.Descriptors.FieldDescriptor field) {
- return super.clearField(field);
+ return (Builder) super.clearField(field);
}
- @java.lang.Override
public Builder clearOneof(
com.google.protobuf.Descriptors.OneofDescriptor oneof) {
- return super.clearOneof(oneof);
+ return (Builder) super.clearOneof(oneof);
}
- @java.lang.Override
public Builder setRepeatedField(
com.google.protobuf.Descriptors.FieldDescriptor field,
- int index, java.lang.Object value) {
- return super.setRepeatedField(field, index, value);
+ int index, Object value) {
+ return (Builder) super.setRepeatedField(field, index, value);
}
- @java.lang.Override
public Builder addRepeatedField(
com.google.protobuf.Descriptors.FieldDescriptor field,
- java.lang.Object value) {
- return super.addRepeatedField(field, value);
+ Object value) {
+ return (Builder) super.addRepeatedField(field, value);
}
- @java.lang.Override
public Builder mergeFrom(com.google.protobuf.Message other) {
if (other instanceof com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.TensorMap) {
return mergeFrom((com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.TensorMap)other);
@@ -1722,17 +1707,16 @@ public Builder mergeFrom(com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto
}
internalGetMutableTensorMap().mergeFrom(
other.internalGetTensorMap());
- this.mergeUnknownFields(other.unknownFields);
+ internalGetMutableEncryptedTensorMap().mergeFrom(
+ other.internalGetEncryptedTensorMap());
onChanged();
return this;
}
- @java.lang.Override
public final boolean isInitialized() {
return true;
}
- @java.lang.Override
public Builder mergeFrom(
com.google.protobuf.CodedInputStream input,
com.google.protobuf.ExtensionRegistryLite extensionRegistry)
@@ -1752,19 +1736,17 @@ public Builder mergeFrom(
}
private int bitField0_;
- private com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData metaData_;
+ private com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData metaData_ = null;
private com.google.protobuf.SingleFieldBuilderV3<
com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData.Builder, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaDataOrBuilder> metaDataBuilder_;
/**
- * .MetaData metaData = 1;
- * @return Whether the metaData field is set.
+ * optional .MetaData metaData = 1;
*/
public boolean hasMetaData() {
return metaDataBuilder_ != null || metaData_ != null;
}
/**
- * .MetaData metaData = 1;
- * @return The metaData.
+ * optional .MetaData metaData = 1;
*/
public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData getMetaData() {
if (metaDataBuilder_ == null) {
@@ -1774,7 +1756,7 @@ public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData getMetaD
}
}
/**
- * .MetaData metaData = 1;
+ * optional .MetaData metaData = 1;
*/
public Builder setMetaData(com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData value) {
if (metaDataBuilder_ == null) {
@@ -1790,7 +1772,7 @@ public Builder setMetaData(com.intel.analytics.bigdl.ppml.fl.generated.FlBasePro
return this;
}
/**
- * .MetaData metaData = 1;
+ * optional .MetaData metaData = 1;
*/
public Builder setMetaData(
com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData.Builder builderForValue) {
@@ -1804,7 +1786,7 @@ public Builder setMetaData(
return this;
}
/**
- * .MetaData metaData = 1;
+ * optional .MetaData metaData = 1;
*/
public Builder mergeMetaData(com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData value) {
if (metaDataBuilder_ == null) {
@@ -1822,7 +1804,7 @@ public Builder mergeMetaData(com.intel.analytics.bigdl.ppml.fl.generated.FlBaseP
return this;
}
/**
- * .MetaData metaData = 1;
+ * optional .MetaData metaData = 1;
*/
public Builder clearMetaData() {
if (metaDataBuilder_ == null) {
@@ -1836,7 +1818,7 @@ public Builder clearMetaData() {
return this;
}
/**
- * .MetaData metaData = 1;
+ * optional .MetaData metaData = 1;
*/
public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData.Builder getMetaDataBuilder() {
@@ -1844,7 +1826,7 @@ public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData.Builder
return getMetaDataFieldBuilder().getBuilder();
}
/**
- * .MetaData metaData = 1;
+ * optional .MetaData metaData = 1;
*/
public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaDataOrBuilder getMetaDataOrBuilder() {
if (metaDataBuilder_ != null) {
@@ -1855,7 +1837,7 @@ public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaDataOrBuilder
}
}
/**
- * .MetaData metaData = 1;
+ * optional .MetaData metaData = 1;
*/
private com.google.protobuf.SingleFieldBuilderV3<
com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData.Builder, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaDataOrBuilder>
@@ -1901,7 +1883,6 @@ public int getTensorMapCount() {
/**
* map<string, .FloatTensor> tensorMap = 2;
*/
- @java.lang.Override
public boolean containsTensorMap(
java.lang.String key) {
if (key == null) { throw new java.lang.NullPointerException(); }
@@ -1910,7 +1891,6 @@ public boolean containsTensorMap(
/**
* Use {@link #getTensorMapMap()} instead.
*/
- @java.lang.Override
@java.lang.Deprecated
public java.util.Map<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.FloatTensor> getTensorMap() {
return getTensorMapMap();
@@ -1918,7 +1898,6 @@ public java.util.Map getTensorMap() {
/**
* map<string, .FloatTensor> tensorMap = 2;
*/
- @java.lang.Override
public java.util.Map<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.FloatTensor> getTensorMapMap() {
return internalGetTensorMap().getMap();
@@ -1926,7 +1905,6 @@ public java.util.Map getTensorMapMap() {
/**
* map<string, .FloatTensor> tensorMap = 2;
*/
- @java.lang.Override
public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.FloatTensor getTensorMapOrDefault(
java.lang.String key,
@@ -1939,7 +1917,6 @@ public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.FloatTensor getTe
/**
* map<string, .FloatTensor> tensorMap = 2;
*/
- @java.lang.Override
public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.FloatTensor getTensorMapOrThrow(
java.lang.String key) {
@@ -1953,8 +1930,7 @@ public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.FloatTensor getTe
}
public Builder clearTensorMap() {
- internalGetMutableTensorMap().getMutableMap()
- .clear();
+ getMutableTensorMap().clear();
return this;
}
/**
@@ -1964,8 +1940,7 @@ public Builder clearTensorMap() {
public Builder removeTensorMap(
java.lang.String key) {
if (key == null) { throw new java.lang.NullPointerException(); }
- internalGetMutableTensorMap().getMutableMap()
- .remove(key);
+ getMutableTensorMap().remove(key);
return this;
}
/**
@@ -1984,8 +1959,7 @@ public Builder putTensorMap(
com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.FloatTensor value) {
if (key == null) { throw new java.lang.NullPointerException(); }
if (value == null) { throw new java.lang.NullPointerException(); }
- internalGetMutableTensorMap().getMutableMap()
- .put(key, value);
+ getMutableTensorMap().put(key, value);
return this;
}
/**
@@ -1994,20 +1968,136 @@ public Builder putTensorMap(
public Builder putAllTensorMap(
java.util.Map<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.FloatTensor> values) {
- internalGetMutableTensorMap().getMutableMap()
- .putAll(values);
+ getMutableTensorMap().putAll(values);
+ return this;
+ }
+
+ private com.google.protobuf.MapField<
+ java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor> encryptedTensorMap_;
+ private com.google.protobuf.MapField<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor>
+ internalGetEncryptedTensorMap() {
+ if (encryptedTensorMap_ == null) {
+ return com.google.protobuf.MapField.emptyMapField(
+ EncryptedTensorMapDefaultEntryHolder.defaultEntry);
+ }
+ return encryptedTensorMap_;
+ }
+ private com.google.protobuf.MapField<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor>
+ internalGetMutableEncryptedTensorMap() {
+ onChanged();
+ if (encryptedTensorMap_ == null) {
+ encryptedTensorMap_ = com.google.protobuf.MapField.newMapField(
+ EncryptedTensorMapDefaultEntryHolder.defaultEntry);
+ }
+ if (!encryptedTensorMap_.isMutable()) {
+ encryptedTensorMap_ = encryptedTensorMap_.copy();
+ }
+ return encryptedTensorMap_;
+ }
+
+ public int getEncryptedTensorMapCount() {
+ return internalGetEncryptedTensorMap().getMap().size();
+ }
+ /**
+ * map<string, .EncryptedTensor> encryptedTensorMap = 3;
+ */
+
+ public boolean containsEncryptedTensorMap(
+ java.lang.String key) {
+ if (key == null) { throw new java.lang.NullPointerException(); }
+ return internalGetEncryptedTensorMap().getMap().containsKey(key);
+ }
+ /**
+ * Use {@link #getEncryptedTensorMapMap()} instead.
+ */
+ @java.lang.Deprecated
+ public java.util.Map<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor> getEncryptedTensorMap() {
+ return getEncryptedTensorMapMap();
+ }
+ /**
+ * map<string, .EncryptedTensor> encryptedTensorMap = 3;
+ */
+
+ public java.util.Map<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor> getEncryptedTensorMapMap() {
+ return internalGetEncryptedTensorMap().getMap();
+ }
+ /**
+ * map<string, .EncryptedTensor> encryptedTensorMap = 3;
+ */
+
+ public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor getEncryptedTensorMapOrDefault(
+ java.lang.String key,
+ com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor defaultValue) {
+ if (key == null) { throw new java.lang.NullPointerException(); }
+ java.util.Map<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor> map =
+ internalGetEncryptedTensorMap().getMap();
+ return map.containsKey(key) ? map.get(key) : defaultValue;
+ }
+ /**
+ * map<string, .EncryptedTensor> encryptedTensorMap = 3;
+ */
+
+ public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor getEncryptedTensorMapOrThrow(
+ java.lang.String key) {
+ if (key == null) { throw new java.lang.NullPointerException(); }
+ java.util.Map<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor> map =
+ internalGetEncryptedTensorMap().getMap();
+ if (!map.containsKey(key)) {
+ throw new java.lang.IllegalArgumentException();
+ }
+ return map.get(key);
+ }
+
+ public Builder clearEncryptedTensorMap() {
+ getMutableEncryptedTensorMap().clear();
+ return this;
+ }
+ /**
+ * map<string, .EncryptedTensor> encryptedTensorMap = 3;
+ */
+
+ public Builder removeEncryptedTensorMap(
+ java.lang.String key) {
+ if (key == null) { throw new java.lang.NullPointerException(); }
+ getMutableEncryptedTensorMap().remove(key);
+ return this;
+ }
+ /**
+ * Use alternate mutation accessors instead.
+ */
+ @java.lang.Deprecated
+ public java.util.Map<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor>
+ getMutableEncryptedTensorMap() {
+ return internalGetMutableEncryptedTensorMap().getMutableMap();
+ }
+ /**
+ * map<string, .EncryptedTensor> encryptedTensorMap = 3;
+ */
+ public Builder putEncryptedTensorMap(
+ java.lang.String key,
+ com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor value) {
+ if (key == null) { throw new java.lang.NullPointerException(); }
+ if (value == null) { throw new java.lang.NullPointerException(); }
+ getMutableEncryptedTensorMap().put(key, value);
+ return this;
+ }
+ /**
+ * map<string, .EncryptedTensor> encryptedTensorMap = 3;
+ */
+
+ public Builder putAllEncryptedTensorMap(
+ java.util.Map<java.lang.String, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor> values) {
+ getMutableEncryptedTensorMap().putAll(values);
return this;
}
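A minimal producer-side sketch of the builder methods above, assuming the standard protoc accessors for the new message (`addShape` for the repeated `shape` field) and that ciphertext bytes come from some encryption layer elsewhere; the class name, the `ciphertext` argument, and the "fc1.weight"/"gradients" labels are illustrative, not part of the generated code.

  import com.google.protobuf.ByteString;
  import com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto;

  final class EncryptedTensorMapWriter {
    static FlBaseProto.TensorMap build(byte[] ciphertext) {
      // A 2x3 encrypted tensor; the payload bytes are opaque to the protocol.
      FlBaseProto.EncryptedTensor enc = FlBaseProto.EncryptedTensor.newBuilder()
          .addShape(2).addShape(3)
          .setTensor(ByteString.copyFrom(ciphertext))
          .setDtype("float32")
          .build();
      return FlBaseProto.TensorMap.newBuilder()
          .setMetaData(FlBaseProto.MetaData.newBuilder()
              .setName("gradients").setVersion(1))
          .putEncryptedTensorMap("fc1.weight", enc)  // new map field 3
          .build();
    }
  }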
- @java.lang.Override
public final Builder setUnknownFields(
final com.google.protobuf.UnknownFieldSet unknownFields) {
- return super.setUnknownFields(unknownFields);
+ return this;
}
- @java.lang.Override
public final Builder mergeUnknownFields(
final com.google.protobuf.UnknownFieldSet unknownFields) {
- return super.mergeUnknownFields(unknownFields);
+ return this;
}
@@ -2026,12 +2116,11 @@ public static com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.TensorMap
private static final com.google.protobuf.Parser<TensorMap>
PARSER = new com.google.protobuf.AbstractParser<TensorMap>() {
- @java.lang.Override
public TensorMap parsePartialFrom(
com.google.protobuf.CodedInputStream input,
com.google.protobuf.ExtensionRegistryLite extensionRegistry)
throws com.google.protobuf.InvalidProtocolBufferException {
- return new TensorMap(input, extensionRegistry);
+ return new TensorMap(input, extensionRegistry);
}
};
@@ -2044,7 +2133,6 @@ public com.google.protobuf.Parser getParserForType() {
return PARSER;
}
- @java.lang.Override
public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.TensorMap getDefaultInstanceForType() {
return DEFAULT_INSTANCE;
}
@@ -2056,20 +2144,17 @@ public interface MetaDataOrBuilder extends
com.google.protobuf.MessageOrBuilder {
/**
- * string name = 1;
- * @return The name.
+ * optional string name = 1;
*/
java.lang.String getName();
/**
- * string name = 1;
- * @return The bytes for name.
+ * optional string name = 1;
*/
com.google.protobuf.ByteString
getNameBytes();
/**
- * int32 version = 2;
- * @return The version.
+ * optional int32 version = 2;
*/
int getVersion();
}
@@ -2079,41 +2164,30 @@ public interface MetaDataOrBuilder extends
*
* Protobuf type {@code MetaData}
*/
- public static final class MetaData extends
+ public static final class MetaData extends
com.google.protobuf.GeneratedMessageV3 implements
// @@protoc_insertion_point(message_implements:MetaData)
MetaDataOrBuilder {
- private static final long serialVersionUID = 0L;
// Use MetaData.newBuilder() to construct.
private MetaData(com.google.protobuf.GeneratedMessageV3.Builder<?> builder) {
super(builder);
}
private MetaData() {
name_ = "";
- }
-
- @java.lang.Override
- @SuppressWarnings({"unused"})
- protected java.lang.Object newInstance(
- UnusedPrivateParameter unused) {
- return new MetaData();
+ version_ = 0;
}
@java.lang.Override
public final com.google.protobuf.UnknownFieldSet
getUnknownFields() {
- return this.unknownFields;
+ return com.google.protobuf.UnknownFieldSet.getDefaultInstance();
}
private MetaData(
com.google.protobuf.CodedInputStream input,
com.google.protobuf.ExtensionRegistryLite extensionRegistry)
throws com.google.protobuf.InvalidProtocolBufferException {
this();
- if (extensionRegistry == null) {
- throw new java.lang.NullPointerException();
- }
- com.google.protobuf.UnknownFieldSet.Builder unknownFields =
- com.google.protobuf.UnknownFieldSet.newBuilder();
+ int mutable_bitField0_ = 0;
try {
boolean done = false;
while (!done) {
@@ -2122,6 +2196,12 @@ private MetaData(
case 0:
done = true;
break;
+ default: {
+ if (!input.skipField(tag)) {
+ done = true;
+ }
+ break;
+ }
case 10: {
java.lang.String s = input.readStringRequireUtf8();
@@ -2133,13 +2213,6 @@ private MetaData(
version_ = input.readInt32();
break;
}
- default: {
- if (!parseUnknownField(
- input, unknownFields, extensionRegistry, tag)) {
- done = true;
- }
- break;
- }
}
}
} catch (com.google.protobuf.InvalidProtocolBufferException e) {
@@ -2148,7 +2221,6 @@ private MetaData(
throw new com.google.protobuf.InvalidProtocolBufferException(
e).setUnfinishedMessage(this);
} finally {
- this.unknownFields = unknownFields.build();
makeExtensionsImmutable();
}
}
@@ -2157,7 +2229,6 @@ private MetaData(
return com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.internal_static_MetaData_descriptor;
}
- @java.lang.Override
protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable
internalGetFieldAccessorTable() {
return com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.internal_static_MetaData_fieldAccessorTable
@@ -2168,10 +2239,8 @@ private MetaData(
public static final int NAME_FIELD_NUMBER = 1;
private volatile java.lang.Object name_;
/**
- * string name = 1;
- * @return The name.
+ * optional string name = 1;
*/
- @java.lang.Override
public java.lang.String getName() {
java.lang.Object ref = name_;
if (ref instanceof java.lang.String) {
@@ -2185,10 +2254,8 @@ public java.lang.String getName() {
}
}
/**
- * string name = 1;
- * @return The bytes for name.
+ * optional string name = 1;
*/
- @java.lang.Override
public com.google.protobuf.ByteString
getNameBytes() {
java.lang.Object ref = name_;
@@ -2206,16 +2273,13 @@ public java.lang.String getName() {
public static final int VERSION_FIELD_NUMBER = 2;
private int version_;
/**
- * int32 version = 2;
- * @return The version.
+ * optional int32 version = 2;
*/
- @java.lang.Override
public int getVersion() {
return version_;
}
private byte memoizedIsInitialized = -1;
- @java.lang.Override
public final boolean isInitialized() {
byte isInitialized = memoizedIsInitialized;
if (isInitialized == 1) return true;
@@ -2225,7 +2289,6 @@ public final boolean isInitialized() {
return true;
}
- @java.lang.Override
public void writeTo(com.google.protobuf.CodedOutputStream output)
throws java.io.IOException {
if (!getNameBytes().isEmpty()) {
@@ -2234,10 +2297,8 @@ public void writeTo(com.google.protobuf.CodedOutputStream output)
if (version_ != 0) {
output.writeInt32(2, version_);
}
- unknownFields.writeTo(output);
}
- @java.lang.Override
public int getSerializedSize() {
int size = memoizedSize;
if (size != -1) return size;
@@ -2250,11 +2311,11 @@ public int getSerializedSize() {
size += com.google.protobuf.CodedOutputStream
.computeInt32Size(2, version_);
}
- size += unknownFields.getSerializedSize();
memoizedSize = size;
return size;
}
+ private static final long serialVersionUID = 0L;
@java.lang.Override
public boolean equals(final java.lang.Object obj) {
if (obj == this) {
@@ -2265,12 +2326,12 @@ public boolean equals(final java.lang.Object obj) {
}
com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData other = (com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData) obj;
- if (!getName()
- .equals(other.getName())) return false;
- if (getVersion()
- != other.getVersion()) return false;
- if (!unknownFields.equals(other.unknownFields)) return false;
- return true;
+ boolean result = true;
+ result = result && getName()
+ .equals(other.getName());
+ result = result && (getVersion()
+ == other.getVersion());
+ return result;
}
@java.lang.Override
@@ -2279,7 +2340,7 @@ public int hashCode() {
return memoizedHashCode;
}
int hash = 41;
- hash = (19 * hash) + getDescriptor().hashCode();
+ hash = (19 * hash) + getDescriptorForType().hashCode();
hash = (37 * hash) + NAME_FIELD_NUMBER;
hash = (53 * hash) + getName().hashCode();
hash = (37 * hash) + VERSION_FIELD_NUMBER;
@@ -2290,18 +2351,7 @@ public int hashCode() {
}
public static com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData parseFrom(
- java.nio.ByteBuffer data)
- throws com.google.protobuf.InvalidProtocolBufferException {
- return PARSER.parseFrom(data);
- }
- public static com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData parseFrom(
- java.nio.ByteBuffer data,
- com.google.protobuf.ExtensionRegistryLite extensionRegistry)
- throws com.google.protobuf.InvalidProtocolBufferException {
- return PARSER.parseFrom(data, extensionRegistry);
- }
- public static com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData parseFrom(
- com.google.protobuf.ByteString data)
+ com.google.protobuf.ByteString data)
throws com.google.protobuf.InvalidProtocolBufferException {
return PARSER.parseFrom(data);
}
@@ -2359,7 +2409,6 @@ public static com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData p
.parseWithIOException(PARSER, input, extensionRegistry);
}
- @java.lang.Override
public Builder newBuilderForType() { return newBuilder(); }
public static Builder newBuilder() {
return DEFAULT_INSTANCE.toBuilder();
@@ -2367,7 +2416,6 @@ public static Builder newBuilder() {
public static Builder newBuilder(com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData prototype) {
return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype);
}
- @java.lang.Override
public Builder toBuilder() {
return this == DEFAULT_INSTANCE
? new Builder() : new Builder().mergeFrom(this);
@@ -2394,7 +2442,6 @@ public static final class Builder extends
return com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.internal_static_MetaData_descriptor;
}
- @java.lang.Override
protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable
internalGetFieldAccessorTable() {
return com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.internal_static_MetaData_fieldAccessorTable
@@ -2417,7 +2464,6 @@ private void maybeForceBuilderInitialization() {
.alwaysUseFieldBuilders) {
}
}
- @java.lang.Override
public Builder clear() {
super.clear();
name_ = "";
@@ -2427,18 +2473,15 @@ public Builder clear() {
return this;
}
- @java.lang.Override
public com.google.protobuf.Descriptors.Descriptor
getDescriptorForType() {
return com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.internal_static_MetaData_descriptor;
}
- @java.lang.Override
public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData getDefaultInstanceForType() {
return com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData.getDefaultInstance();
}
- @java.lang.Override
public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData build() {
com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData result = buildPartial();
if (!result.isInitialized()) {
@@ -2447,7 +2490,6 @@ public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData build()
return result;
}
- @java.lang.Override
public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData buildPartial() {
com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData result = new com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData(this);
result.name_ = name_;
@@ -2456,39 +2498,32 @@ public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData buildPar
return result;
}
- @java.lang.Override
public Builder clone() {
- return super.clone();
+ return (Builder) super.clone();
}
- @java.lang.Override
public Builder setField(
com.google.protobuf.Descriptors.FieldDescriptor field,
- java.lang.Object value) {
- return super.setField(field, value);
+ Object value) {
+ return (Builder) super.setField(field, value);
}
- @java.lang.Override
public Builder clearField(
com.google.protobuf.Descriptors.FieldDescriptor field) {
- return super.clearField(field);
+ return (Builder) super.clearField(field);
}
- @java.lang.Override
public Builder clearOneof(
com.google.protobuf.Descriptors.OneofDescriptor oneof) {
- return super.clearOneof(oneof);
+ return (Builder) super.clearOneof(oneof);
}
- @java.lang.Override
public Builder setRepeatedField(
com.google.protobuf.Descriptors.FieldDescriptor field,
- int index, java.lang.Object value) {
- return super.setRepeatedField(field, index, value);
+ int index, Object value) {
+ return (Builder) super.setRepeatedField(field, index, value);
}
- @java.lang.Override
public Builder addRepeatedField(
com.google.protobuf.Descriptors.FieldDescriptor field,
- java.lang.Object value) {
- return super.addRepeatedField(field, value);
+ Object value) {
+ return (Builder) super.addRepeatedField(field, value);
}
- @java.lang.Override
public Builder mergeFrom(com.google.protobuf.Message other) {
if (other instanceof com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData) {
return mergeFrom((com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData)other);
@@ -2507,17 +2542,14 @@ public Builder mergeFrom(com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto
if (other.getVersion() != 0) {
setVersion(other.getVersion());
}
- this.mergeUnknownFields(other.unknownFields);
onChanged();
return this;
}
- @java.lang.Override
public final boolean isInitialized() {
return true;
}
- @java.lang.Override
public Builder mergeFrom(
com.google.protobuf.CodedInputStream input,
com.google.protobuf.ExtensionRegistryLite extensionRegistry)
@@ -2538,8 +2570,7 @@ public Builder mergeFrom(
private java.lang.Object name_ = "";
/**
- * string name = 1;
- * @return The name.
+ * optional string name = 1;
*/
public java.lang.String getName() {
java.lang.Object ref = name_;
@@ -2554,8 +2585,7 @@ public java.lang.String getName() {
}
}
/**
- * string name = 1;
- * @return The bytes for name.
+ * optional string name = 1;
*/
public com.google.protobuf.ByteString
getNameBytes() {
@@ -2571,9 +2601,7 @@ public java.lang.String getName() {
}
}
/**
- * string name = 1;
- * @param value The name to set.
- * @return This builder for chaining.
+ * optional string name = 1;
*/
public Builder setName(
java.lang.String value) {
@@ -2586,8 +2614,7 @@ public Builder setName(
return this;
}
/**
- * string name = 1;
- * @return This builder for chaining.
+ * optional string name = 1;
*/
public Builder clearName() {
@@ -2596,9 +2623,7 @@ public Builder clearName() {
return this;
}
/**
- * string name = 1;
- * @param value The bytes for name to set.
- * @return This builder for chaining.
+ * optional string name = 1;
*/
public Builder setNameBytes(
com.google.protobuf.ByteString value) {
@@ -2614,17 +2639,13 @@ public Builder setNameBytes(
private int version_ ;
/**
- * int32 version = 2;
- * @return The version.
+ * optional int32 version = 2;
*/
- @java.lang.Override
public int getVersion() {
return version_;
}
/**
- * int32 version = 2;
- * @param value The version to set.
- * @return This builder for chaining.
+ * optional int32 version = 2;
*/
public Builder setVersion(int value) {
@@ -2633,8 +2654,7 @@ public Builder setVersion(int value) {
return this;
}
/**
- * int32 version = 2;
- * @return This builder for chaining.
+ * optional int32 version = 2;
*/
public Builder clearVersion() {
@@ -2642,16 +2662,14 @@ public Builder clearVersion() {
onChanged();
return this;
}
- @java.lang.Override
public final Builder setUnknownFields(
final com.google.protobuf.UnknownFieldSet unknownFields) {
- return super.setUnknownFields(unknownFields);
+ return this;
}
- @java.lang.Override
public final Builder mergeUnknownFields(
final com.google.protobuf.UnknownFieldSet unknownFields) {
- return super.mergeUnknownFields(unknownFields);
+ return this;
}
@@ -2670,12 +2688,11 @@ public static com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData g
private static final com.google.protobuf.Parser<MetaData>
PARSER = new com.google.protobuf.AbstractParser<MetaData>() {
- @java.lang.Override
public MetaData parsePartialFrom(
com.google.protobuf.CodedInputStream input,
com.google.protobuf.ExtensionRegistryLite extensionRegistry)
throws com.google.protobuf.InvalidProtocolBufferException {
- return new MetaData(input, extensionRegistry);
+ return new MetaData(input, extensionRegistry);
}
};
@@ -2688,82 +2705,862 @@ public com.google.protobuf.Parser getParserForType() {
return PARSER;
}
- @java.lang.Override
public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.MetaData getDefaultInstanceForType() {
return DEFAULT_INSTANCE;
}
}
- private static final com.google.protobuf.Descriptors.Descriptor
- internal_static_FloatTensor_descriptor;
- private static final
- com.google.protobuf.GeneratedMessageV3.FieldAccessorTable
- internal_static_FloatTensor_fieldAccessorTable;
- private static final com.google.protobuf.Descriptors.Descriptor
- internal_static_TensorMap_descriptor;
- private static final
- com.google.protobuf.GeneratedMessageV3.FieldAccessorTable
- internal_static_TensorMap_fieldAccessorTable;
- private static final com.google.protobuf.Descriptors.Descriptor
- internal_static_TensorMap_TensorMapEntry_descriptor;
- private static final
- com.google.protobuf.GeneratedMessageV3.FieldAccessorTable
- internal_static_TensorMap_TensorMapEntry_fieldAccessorTable;
- private static final com.google.protobuf.Descriptors.Descriptor
- internal_static_MetaData_descriptor;
- private static final
- com.google.protobuf.GeneratedMessageV3.FieldAccessorTable
- internal_static_MetaData_fieldAccessorTable;
+ public interface EncryptedTensorOrBuilder extends
+ // @@protoc_insertion_point(interface_extends:EncryptedTensor)
+ com.google.protobuf.MessageOrBuilder {
- public static com.google.protobuf.Descriptors.FileDescriptor
- getDescriptor() {
- return descriptor;
+ /**
+ * repeated int32 shape = 1;
+ */
+ java.util.List<java.lang.Integer> getShapeList();
+ /**
+ * repeated int32 shape = 1;
+ */
+ int getShapeCount();
+ /**
+ * repeated int32 shape = 1;
+ */
+ int getShape(int index);
+
+ /**
+ * optional bytes tensor = 2;
+ */
+ com.google.protobuf.ByteString getTensor();
+
+ /**
+ * optional string dtype = 3;
+ */
+ java.lang.String getDtype();
+ /**
+ * optional string dtype = 3;
+ */
+ com.google.protobuf.ByteString
+ getDtypeBytes();
}
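As a sanity check on the wire format implemented by the generated writeTo/parseFrom pair below, a minimal round-trip sketch (the class and method names are illustrative):

  import com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto;

  final class EncryptedTensorRoundTrip {
    static FlBaseProto.EncryptedTensor roundTrip(FlBaseProto.EncryptedTensor in)
        throws com.google.protobuf.InvalidProtocolBufferException {
      byte[] wire = in.toByteArray();                      // serializes via writeTo
      return FlBaseProto.EncryptedTensor.parseFrom(wire);  // reparses via PARSER
    }
  }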
- private static com.google.protobuf.Descriptors.FileDescriptor
- descriptor;
- static {
- java.lang.String[] descriptorData = {
- "\n\rfl_base.proto\";\n\013FloatTensor\022\r\n\005shape\030" +
- "\001 \003(\005\022\016\n\006tensor\030\002 \003(\002\022\r\n\005dtype\030\003 \001(\t\"\226\001\n" +
- "\tTensorMap\022\033\n\010metaData\030\001 \001(\0132\t.MetaData\022" +
- ",\n\ttensorMap\030\002 \003(\0132\031.TensorMap.TensorMap" +
- "Entry\032>\n\016TensorMapEntry\022\013\n\003key\030\001 \001(\t\022\033\n\005" +
- "value\030\002 \001(\0132\014.FloatTensor:\0028\001\")\n\010MetaDat" +
- "a\022\014\n\004name\030\001 \001(\t\022\017\n\007version\030\002 \001(\005*H\n\006SIGN" +
- "AL\022\013\n\007SUCCESS\020\000\022\010\n\004WAIT\020\001\022\013\n\007TIMEOUT\020\002\022\017" +
- "\n\013EMPTY_INPUT\020\003\022\t\n\005ERROR\020\004B:\n+com.intel." +
- "analytics.bigdl.ppml.fl.generatedB\013FlBas" +
- "eProtob\006proto3"
- };
- descriptor = com.google.protobuf.Descriptors.FileDescriptor
- .internalBuildGeneratedFileFrom(descriptorData,
- new com.google.protobuf.Descriptors.FileDescriptor[] {
- });
- internal_static_FloatTensor_descriptor =
- getDescriptor().getMessageTypes().get(0);
- internal_static_FloatTensor_fieldAccessorTable = new
- com.google.protobuf.GeneratedMessageV3.FieldAccessorTable(
- internal_static_FloatTensor_descriptor,
- new java.lang.String[] { "Shape", "Tensor", "Dtype", });
- internal_static_TensorMap_descriptor =
- getDescriptor().getMessageTypes().get(1);
- internal_static_TensorMap_fieldAccessorTable = new
- com.google.protobuf.GeneratedMessageV3.FieldAccessorTable(
- internal_static_TensorMap_descriptor,
- new java.lang.String[] { "MetaData", "TensorMap", });
- internal_static_TensorMap_TensorMapEntry_descriptor =
- internal_static_TensorMap_descriptor.getNestedTypes().get(0);
- internal_static_TensorMap_TensorMapEntry_fieldAccessorTable = new
- com.google.protobuf.GeneratedMessageV3.FieldAccessorTable(
- internal_static_TensorMap_TensorMapEntry_descriptor,
- new java.lang.String[] { "Key", "Value", });
- internal_static_MetaData_descriptor =
- getDescriptor().getMessageTypes().get(2);
- internal_static_MetaData_fieldAccessorTable = new
- com.google.protobuf.GeneratedMessageV3.FieldAccessorTable(
- internal_static_MetaData_descriptor,
- new java.lang.String[] { "Name", "Version", });
+ /**
+ * Protobuf type {@code EncryptedTensor}
+ */
+ public static final class EncryptedTensor extends
+ com.google.protobuf.GeneratedMessageV3 implements
+ // @@protoc_insertion_point(message_implements:EncryptedTensor)
+ EncryptedTensorOrBuilder {
+ // Use EncryptedTensor.newBuilder() to construct.
+ private EncryptedTensor(com.google.protobuf.GeneratedMessageV3.Builder<?> builder) {
+ super(builder);
+ }
+ private EncryptedTensor() {
+ shape_ = java.util.Collections.emptyList();
+ tensor_ = com.google.protobuf.ByteString.EMPTY;
+ dtype_ = "";
+ }
+
+ @java.lang.Override
+ public final com.google.protobuf.UnknownFieldSet
+ getUnknownFields() {
+ return com.google.protobuf.UnknownFieldSet.getDefaultInstance();
+ }
+ private EncryptedTensor(
+ com.google.protobuf.CodedInputStream input,
+ com.google.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws com.google.protobuf.InvalidProtocolBufferException {
+ this();
+ int mutable_bitField0_ = 0;
+ try {
+ boolean done = false;
+ while (!done) {
+ int tag = input.readTag();
+ switch (tag) {
+ case 0:
+ done = true;
+ break;
+ default: {
+ if (!input.skipField(tag)) {
+ done = true;
+ }
+ break;
+ }
+ case 8: {
+ if (!((mutable_bitField0_ & 0x00000001) == 0x00000001)) {
+ shape_ = new java.util.ArrayList<java.lang.Integer>();
+ mutable_bitField0_ |= 0x00000001;
+ }
+ shape_.add(input.readInt32());
+ break;
+ }
+ case 10: {
+ int length = input.readRawVarint32();
+ int limit = input.pushLimit(length);
+ if (!((mutable_bitField0_ & 0x00000001) == 0x00000001) && input.getBytesUntilLimit() > 0) {
+ shape_ = new java.util.ArrayList<java.lang.Integer>();
+ mutable_bitField0_ |= 0x00000001;
+ }
+ while (input.getBytesUntilLimit() > 0) {
+ shape_.add(input.readInt32());
+ }
+ input.popLimit(limit);
+ break;
+ }
+ case 18: {
+
+ tensor_ = input.readBytes();
+ break;
+ }
+ case 26: {
+ java.lang.String s = input.readStringRequireUtf8();
+
+ dtype_ = s;
+ break;
+ }
+ }
+ }
+ } catch (com.google.protobuf.InvalidProtocolBufferException e) {
+ throw e.setUnfinishedMessage(this);
+ } catch (java.io.IOException e) {
+ throw new com.google.protobuf.InvalidProtocolBufferException(
+ e).setUnfinishedMessage(this);
+ } finally {
+ if (((mutable_bitField0_ & 0x00000001) == 0x00000001)) {
+ shape_ = java.util.Collections.unmodifiableList(shape_);
+ }
+ makeExtensionsImmutable();
+ }
+ }
+ public static final com.google.protobuf.Descriptors.Descriptor
+ getDescriptor() {
+ return com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.internal_static_EncryptedTensor_descriptor;
+ }
+
+ protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable
+ internalGetFieldAccessorTable() {
+ return com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.internal_static_EncryptedTensor_fieldAccessorTable
+ .ensureFieldAccessorsInitialized(
+ com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor.class, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor.Builder.class);
+ }
+
+ private int bitField0_;
+ public static final int SHAPE_FIELD_NUMBER = 1;
+ private java.util.List<java.lang.Integer> shape_;
+ /**
+ * repeated int32 shape = 1;
+ */
+ public java.util.List<java.lang.Integer>
+ getShapeList() {
+ return shape_;
+ }
+ /**
+ * repeated int32 shape = 1;
+ */
+ public int getShapeCount() {
+ return shape_.size();
+ }
+ /**
+ * repeated int32 shape = 1;
+ */
+ public int getShape(int index) {
+ return shape_.get(index);
+ }
+ private int shapeMemoizedSerializedSize = -1;
+
+ public static final int TENSOR_FIELD_NUMBER = 2;
+ private com.google.protobuf.ByteString tensor_;
+ /**
+ * optional bytes tensor = 2;
+ */
+ public com.google.protobuf.ByteString getTensor() {
+ return tensor_;
+ }
+
+ public static final int DTYPE_FIELD_NUMBER = 3;
+ private volatile java.lang.Object dtype_;
+ /**
+ * optional string dtype = 3;
+ */
+ public java.lang.String getDtype() {
+ java.lang.Object ref = dtype_;
+ if (ref instanceof java.lang.String) {
+ return (java.lang.String) ref;
+ } else {
+ com.google.protobuf.ByteString bs =
+ (com.google.protobuf.ByteString) ref;
+ java.lang.String s = bs.toStringUtf8();
+ dtype_ = s;
+ return s;
+ }
+ }
+ /**
+ * optional string dtype = 3;
+ */
+ public com.google.protobuf.ByteString
+ getDtypeBytes() {
+ java.lang.Object ref = dtype_;
+ if (ref instanceof java.lang.String) {
+ com.google.protobuf.ByteString b =
+ com.google.protobuf.ByteString.copyFromUtf8(
+ (java.lang.String) ref);
+ dtype_ = b;
+ return b;
+ } else {
+ return (com.google.protobuf.ByteString) ref;
+ }
+ }
+
+ private byte memoizedIsInitialized = -1;
+ public final boolean isInitialized() {
+ byte isInitialized = memoizedIsInitialized;
+ if (isInitialized == 1) return true;
+ if (isInitialized == 0) return false;
+
+ memoizedIsInitialized = 1;
+ return true;
+ }
+
+ public void writeTo(com.google.protobuf.CodedOutputStream output)
+ throws java.io.IOException {
+ getSerializedSize();
+ if (getShapeList().size() > 0) {
+ output.writeUInt32NoTag(10);
+ output.writeUInt32NoTag(shapeMemoizedSerializedSize);
+ }
+ for (int i = 0; i < shape_.size(); i++) {
+ output.writeInt32NoTag(shape_.get(i));
+ }
+ if (!tensor_.isEmpty()) {
+ output.writeBytes(2, tensor_);
+ }
+ if (!getDtypeBytes().isEmpty()) {
+ com.google.protobuf.GeneratedMessageV3.writeString(output, 3, dtype_);
+ }
+ }
+
+ public int getSerializedSize() {
+ int size = memoizedSize;
+ if (size != -1) return size;
+
+ size = 0;
+ {
+ int dataSize = 0;
+ for (int i = 0; i < shape_.size(); i++) {
+ dataSize += com.google.protobuf.CodedOutputStream
+ .computeInt32SizeNoTag(shape_.get(i));
+ }
+ size += dataSize;
+ if (!getShapeList().isEmpty()) {
+ size += 1;
+ size += com.google.protobuf.CodedOutputStream
+ .computeInt32SizeNoTag(dataSize);
+ }
+ shapeMemoizedSerializedSize = dataSize;
+ }
+ if (!tensor_.isEmpty()) {
+ size += com.google.protobuf.CodedOutputStream
+ .computeBytesSize(2, tensor_);
+ }
+ if (!getDtypeBytes().isEmpty()) {
+ size += com.google.protobuf.GeneratedMessageV3.computeStringSize(3, dtype_);
+ }
+ memoizedSize = size;
+ return size;
+ }
+
+ private static final long serialVersionUID = 0L;
+ @java.lang.Override
+ public boolean equals(final java.lang.Object obj) {
+ if (obj == this) {
+ return true;
+ }
+ if (!(obj instanceof com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor)) {
+ return super.equals(obj);
+ }
+ com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor other = (com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor) obj;
+
+ boolean result = true;
+ result = result && getShapeList()
+ .equals(other.getShapeList());
+ result = result && getTensor()
+ .equals(other.getTensor());
+ result = result && getDtype()
+ .equals(other.getDtype());
+ return result;
+ }
+
+ @java.lang.Override
+ public int hashCode() {
+ if (memoizedHashCode != 0) {
+ return memoizedHashCode;
+ }
+ int hash = 41;
+ hash = (19 * hash) + getDescriptorForType().hashCode();
+ if (getShapeCount() > 0) {
+ hash = (37 * hash) + SHAPE_FIELD_NUMBER;
+ hash = (53 * hash) + getShapeList().hashCode();
+ }
+ hash = (37 * hash) + TENSOR_FIELD_NUMBER;
+ hash = (53 * hash) + getTensor().hashCode();
+ hash = (37 * hash) + DTYPE_FIELD_NUMBER;
+ hash = (53 * hash) + getDtype().hashCode();
+ hash = (29 * hash) + unknownFields.hashCode();
+ memoizedHashCode = hash;
+ return hash;
+ }
+
+ public static com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor parseFrom(
+ com.google.protobuf.ByteString data)
+ throws com.google.protobuf.InvalidProtocolBufferException {
+ return PARSER.parseFrom(data);
+ }
+ public static com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor parseFrom(
+ com.google.protobuf.ByteString data,
+ com.google.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws com.google.protobuf.InvalidProtocolBufferException {
+ return PARSER.parseFrom(data, extensionRegistry);
+ }
+ public static com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor parseFrom(byte[] data)
+ throws com.google.protobuf.InvalidProtocolBufferException {
+ return PARSER.parseFrom(data);
+ }
+ public static com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor parseFrom(
+ byte[] data,
+ com.google.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws com.google.protobuf.InvalidProtocolBufferException {
+ return PARSER.parseFrom(data, extensionRegistry);
+ }
+ public static com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor parseFrom(java.io.InputStream input)
+ throws java.io.IOException {
+ return com.google.protobuf.GeneratedMessageV3
+ .parseWithIOException(PARSER, input);
+ }
+ public static com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor parseFrom(
+ java.io.InputStream input,
+ com.google.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws java.io.IOException {
+ return com.google.protobuf.GeneratedMessageV3
+ .parseWithIOException(PARSER, input, extensionRegistry);
+ }
+ public static com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor parseDelimitedFrom(java.io.InputStream input)
+ throws java.io.IOException {
+ return com.google.protobuf.GeneratedMessageV3
+ .parseDelimitedWithIOException(PARSER, input);
+ }
+ public static com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor parseDelimitedFrom(
+ java.io.InputStream input,
+ com.google.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws java.io.IOException {
+ return com.google.protobuf.GeneratedMessageV3
+ .parseDelimitedWithIOException(PARSER, input, extensionRegistry);
+ }
+ public static com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor parseFrom(
+ com.google.protobuf.CodedInputStream input)
+ throws java.io.IOException {
+ return com.google.protobuf.GeneratedMessageV3
+ .parseWithIOException(PARSER, input);
+ }
+ public static com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor parseFrom(
+ com.google.protobuf.CodedInputStream input,
+ com.google.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws java.io.IOException {
+ return com.google.protobuf.GeneratedMessageV3
+ .parseWithIOException(PARSER, input, extensionRegistry);
+ }
+
+ public Builder newBuilderForType() { return newBuilder(); }
+ public static Builder newBuilder() {
+ return DEFAULT_INSTANCE.toBuilder();
+ }
+ public static Builder newBuilder(com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor prototype) {
+ return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype);
+ }
+ public Builder toBuilder() {
+ return this == DEFAULT_INSTANCE
+ ? new Builder() : new Builder().mergeFrom(this);
+ }
+
+ @java.lang.Override
+ protected Builder newBuilderForType(
+ com.google.protobuf.GeneratedMessageV3.BuilderParent parent) {
+ Builder builder = new Builder(parent);
+ return builder;
+ }
+ /**
+ * Protobuf type {@code EncryptedTensor}
+ */
+ public static final class Builder extends
+ com.google.protobuf.GeneratedMessageV3.Builder<Builder> implements
+ // @@protoc_insertion_point(builder_implements:EncryptedTensor)
+ com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensorOrBuilder {
+ public static final com.google.protobuf.Descriptors.Descriptor
+ getDescriptor() {
+ return com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.internal_static_EncryptedTensor_descriptor;
+ }
+
+ protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable
+ internalGetFieldAccessorTable() {
+ return com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.internal_static_EncryptedTensor_fieldAccessorTable
+ .ensureFieldAccessorsInitialized(
+ com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor.class, com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor.Builder.class);
+ }
+
+ // Construct using com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor.newBuilder()
+ private Builder() {
+ maybeForceBuilderInitialization();
+ }
+
+ private Builder(
+ com.google.protobuf.GeneratedMessageV3.BuilderParent parent) {
+ super(parent);
+ maybeForceBuilderInitialization();
+ }
+ private void maybeForceBuilderInitialization() {
+ if (com.google.protobuf.GeneratedMessageV3
+ .alwaysUseFieldBuilders) {
+ }
+ }
+ public Builder clear() {
+ super.clear();
+ shape_ = java.util.Collections.emptyList();
+ bitField0_ = (bitField0_ & ~0x00000001);
+ tensor_ = com.google.protobuf.ByteString.EMPTY;
+
+ dtype_ = "";
+
+ return this;
+ }
+
+ public com.google.protobuf.Descriptors.Descriptor
+ getDescriptorForType() {
+ return com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.internal_static_EncryptedTensor_descriptor;
+ }
+
+ public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor getDefaultInstanceForType() {
+ return com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor.getDefaultInstance();
+ }
+
+ public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor build() {
+ com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor result = buildPartial();
+ if (!result.isInitialized()) {
+ throw newUninitializedMessageException(result);
+ }
+ return result;
+ }
+
+ public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor buildPartial() {
+ com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor result = new com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor(this);
+ int from_bitField0_ = bitField0_;
+ int to_bitField0_ = 0;
+ if (((bitField0_ & 0x00000001) == 0x00000001)) {
+ shape_ = java.util.Collections.unmodifiableList(shape_);
+ bitField0_ = (bitField0_ & ~0x00000001);
+ }
+ result.shape_ = shape_;
+ result.tensor_ = tensor_;
+ result.dtype_ = dtype_;
+ result.bitField0_ = to_bitField0_;
+ onBuilt();
+ return result;
+ }
+
+ public Builder clone() {
+ return (Builder) super.clone();
+ }
+ public Builder setField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ Object value) {
+ return (Builder) super.setField(field, value);
+ }
+ public Builder clearField(
+ com.google.protobuf.Descriptors.FieldDescriptor field) {
+ return (Builder) super.clearField(field);
+ }
+ public Builder clearOneof(
+ com.google.protobuf.Descriptors.OneofDescriptor oneof) {
+ return (Builder) super.clearOneof(oneof);
+ }
+ public Builder setRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ int index, Object value) {
+ return (Builder) super.setRepeatedField(field, index, value);
+ }
+ public Builder addRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ Object value) {
+ return (Builder) super.addRepeatedField(field, value);
+ }
+ public Builder mergeFrom(com.google.protobuf.Message other) {
+ if (other instanceof com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor) {
+ return mergeFrom((com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor)other);
+ } else {
+ super.mergeFrom(other);
+ return this;
+ }
+ }
+
+ public Builder mergeFrom(com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor other) {
+ if (other == com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor.getDefaultInstance()) return this;
+ if (!other.shape_.isEmpty()) {
+ if (shape_.isEmpty()) {
+ shape_ = other.shape_;
+ bitField0_ = (bitField0_ & ~0x00000001);
+ } else {
+ ensureShapeIsMutable();
+ shape_.addAll(other.shape_);
+ }
+ onChanged();
+ }
+ if (other.getTensor() != com.google.protobuf.ByteString.EMPTY) {
+ setTensor(other.getTensor());
+ }
+ if (!other.getDtype().isEmpty()) {
+ dtype_ = other.dtype_;
+ onChanged();
+ }
+ onChanged();
+ return this;
+ }
+
+ public final boolean isInitialized() {
+ return true;
+ }
+
+ public Builder mergeFrom(
+ com.google.protobuf.CodedInputStream input,
+ com.google.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws java.io.IOException {
+ com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor parsedMessage = null;
+ try {
+ parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry);
+ } catch (com.google.protobuf.InvalidProtocolBufferException e) {
+ parsedMessage = (com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor) e.getUnfinishedMessage();
+ throw e.unwrapIOException();
+ } finally {
+ if (parsedMessage != null) {
+ mergeFrom(parsedMessage);
+ }
+ }
+ return this;
+ }
+ private int bitField0_;
+
+ private java.util.List<java.lang.Integer> shape_ = java.util.Collections.emptyList();
+ private void ensureShapeIsMutable() {
+ if (!((bitField0_ & 0x00000001) == 0x00000001)) {
+ shape_ = new java.util.ArrayList<java.lang.Integer>(shape_);
+ bitField0_ |= 0x00000001;
+ }
+ }
+ /**
+ * repeated int32 shape = 1;
+ */
+ public java.util.List<java.lang.Integer>
+ getShapeList() {
+ return java.util.Collections.unmodifiableList(shape_);
+ }
+ /**
+ * repeated int32 shape = 1;
+ */
+ public int getShapeCount() {
+ return shape_.size();
+ }
+ /**
+ * repeated int32 shape = 1;
+ */
+ public int getShape(int index) {
+ return shape_.get(index);
+ }
+ /**
+ * repeated int32 shape = 1;
+ */
+ public Builder setShape(
+ int index, int value) {
+ ensureShapeIsMutable();
+ shape_.set(index, value);
+ onChanged();
+ return this;
+ }
+ /**
+ * repeated int32 shape = 1;
+ */
+ public Builder addShape(int value) {
+ ensureShapeIsMutable();
+ shape_.add(value);
+ onChanged();
+ return this;
+ }
+ /**
+ * repeated int32 shape = 1;
+ */
+ public Builder addAllShape(
+ java.lang.Iterable<? extends java.lang.Integer> values) {
+ ensureShapeIsMutable();
+ com.google.protobuf.AbstractMessageLite.Builder.addAll(
+ values, shape_);
+ onChanged();
+ return this;
+ }
+ /**
+ * repeated int32 shape = 1;
+ */
+ public Builder clearShape() {
+ shape_ = java.util.Collections.emptyList();
+ bitField0_ = (bitField0_ & ~0x00000001);
+ onChanged();
+ return this;
+ }
+
+ private com.google.protobuf.ByteString tensor_ = com.google.protobuf.ByteString.EMPTY;
+ /**
+ * optional bytes tensor = 2;
+ */
+ public com.google.protobuf.ByteString getTensor() {
+ return tensor_;
+ }
+ /**
+ * optional bytes tensor = 2;
+ */
+ public Builder setTensor(com.google.protobuf.ByteString value) {
+ if (value == null) {
+ throw new NullPointerException();
+ }
+
+ tensor_ = value;
+ onChanged();
+ return this;
+ }
+ /**
+ * optional bytes tensor = 2;
+ */
+ public Builder clearTensor() {
+
+ tensor_ = getDefaultInstance().getTensor();
+ onChanged();
+ return this;
+ }
+
+ private java.lang.Object dtype_ = "";
+ /**
+ * optional string dtype = 3;
+ */
+ public java.lang.String getDtype() {
+ java.lang.Object ref = dtype_;
+ if (!(ref instanceof java.lang.String)) {
+ com.google.protobuf.ByteString bs =
+ (com.google.protobuf.ByteString) ref;
+ java.lang.String s = bs.toStringUtf8();
+ dtype_ = s;
+ return s;
+ } else {
+ return (java.lang.String) ref;
+ }
+ }
+ /**
+ * optional string dtype = 3;
+ */
+ public com.google.protobuf.ByteString
+ getDtypeBytes() {
+ java.lang.Object ref = dtype_;
+ if (ref instanceof String) {
+ com.google.protobuf.ByteString b =
+ com.google.protobuf.ByteString.copyFromUtf8(
+ (java.lang.String) ref);
+ dtype_ = b;
+ return b;
+ } else {
+ return (com.google.protobuf.ByteString) ref;
+ }
+ }
+ /**
+ * optional string dtype = 3;
+ */
+ public Builder setDtype(
+ java.lang.String value) {
+ if (value == null) {
+ throw new NullPointerException();
+ }
+
+ dtype_ = value;
+ onChanged();
+ return this;
+ }
+ /**
+ * optional string dtype = 3;
+ */
+ public Builder clearDtype() {
+
+ dtype_ = getDefaultInstance().getDtype();
+ onChanged();
+ return this;
+ }
+ /**
+ * optional string dtype = 3;
+ */
+ public Builder setDtypeBytes(
+ com.google.protobuf.ByteString value) {
+ if (value == null) {
+ throw new NullPointerException();
+ }
+ checkByteStringIsUtf8(value);
+
+ dtype_ = value;
+ onChanged();
+ return this;
+ }
+ public final Builder setUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return this;
+ }
+
+ public final Builder mergeUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return this;
+ }
+
+
+ // @@protoc_insertion_point(builder_scope:EncryptedTensor)
+ }
+
+ // @@protoc_insertion_point(class_scope:EncryptedTensor)
+ private static final com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor DEFAULT_INSTANCE;
+ static {
+ DEFAULT_INSTANCE = new com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor();
+ }
+
+ public static com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor getDefaultInstance() {
+ return DEFAULT_INSTANCE;
+ }
+
+ private static final com.google.protobuf.Parser<EncryptedTensor>
+ PARSER = new com.google.protobuf.AbstractParser<EncryptedTensor>() {
+ public EncryptedTensor parsePartialFrom(
+ com.google.protobuf.CodedInputStream input,
+ com.google.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws com.google.protobuf.InvalidProtocolBufferException {
+ return new EncryptedTensor(input, extensionRegistry);
+ }
+ };
+
+ public static com.google.protobuf.Parser<EncryptedTensor> parser() {
+ return PARSER;
+ }
+
+ @java.lang.Override
+ public com.google.protobuf.Parser<EncryptedTensor> getParserForType() {
+ return PARSER;
+ }
+
+ public com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.EncryptedTensor getDefaultInstanceForType() {
+ return DEFAULT_INSTANCE;
+ }
+
+ }
+
+ private static final com.google.protobuf.Descriptors.Descriptor
+ internal_static_FloatTensor_descriptor;
+ private static final
+ com.google.protobuf.GeneratedMessageV3.FieldAccessorTable
+ internal_static_FloatTensor_fieldAccessorTable;
+ private static final com.google.protobuf.Descriptors.Descriptor
+ internal_static_TensorMap_descriptor;
+ private static final
+ com.google.protobuf.GeneratedMessageV3.FieldAccessorTable
+ internal_static_TensorMap_fieldAccessorTable;
+ private static final com.google.protobuf.Descriptors.Descriptor
+ internal_static_TensorMap_TensorMapEntry_descriptor;
+ private static final
+ com.google.protobuf.GeneratedMessageV3.FieldAccessorTable
+ internal_static_TensorMap_TensorMapEntry_fieldAccessorTable;
+ private static final com.google.protobuf.Descriptors.Descriptor
+ internal_static_TensorMap_EncryptedTensorMapEntry_descriptor;
+ private static final
+ com.google.protobuf.GeneratedMessageV3.FieldAccessorTable
+ internal_static_TensorMap_EncryptedTensorMapEntry_fieldAccessorTable;
+ private static final com.google.protobuf.Descriptors.Descriptor
+ internal_static_MetaData_descriptor;
+ private static final
+ com.google.protobuf.GeneratedMessageV3.FieldAccessorTable
+ internal_static_MetaData_fieldAccessorTable;
+ private static final com.google.protobuf.Descriptors.Descriptor
+ internal_static_EncryptedTensor_descriptor;
+ private static final
+ com.google.protobuf.GeneratedMessageV3.FieldAccessorTable
+ internal_static_EncryptedTensor_fieldAccessorTable;
+
+ public static com.google.protobuf.Descriptors.FileDescriptor
+ getDescriptor() {
+ return descriptor;
+ }
+ private static com.google.protobuf.Descriptors.FileDescriptor
+ descriptor;
+ static {
+ java.lang.String[] descriptorData = {
+ "\n\030main/proto/fl_base.proto\";\n\013FloatTenso" +
+ "r\022\r\n\005shape\030\001 \003(\005\022\016\n\006tensor\030\002 \003(\002\022\r\n\005dtyp" +
+ "e\030\003 \001(\t\"\243\002\n\tTensorMap\022\033\n\010metaData\030\001 \001(\0132" +
+ "\t.MetaData\022,\n\ttensorMap\030\002 \003(\0132\031.TensorMa" +
+ "p.TensorMapEntry\022>\n\022encryptedTensorMap\030\003" +
+ " \003(\0132\".TensorMap.EncryptedTensorMapEntry" +
+ "\032>\n\016TensorMapEntry\022\013\n\003key\030\001 \001(\t\022\033\n\005value" +
+ "\030\002 \001(\0132\014.FloatTensor:\0028\001\032K\n\027EncryptedTen" +
+ "sorMapEntry\022\013\n\003key\030\001 \001(\t\022\037\n\005value\030\002 \001(\0132" +
+ "\020.EncryptedTensor:\0028\001\")\n\010MetaData\022\014\n\004nam",
+ "e\030\001 \001(\t\022\017\n\007version\030\002 \001(\005\"?\n\017EncryptedTen" +
+ "sor\022\r\n\005shape\030\001 \003(\005\022\016\n\006tensor\030\002 \001(\014\022\r\n\005dt" +
+ "ype\030\003 \001(\t*H\n\006SIGNAL\022\013\n\007SUCCESS\020\000\022\010\n\004WAIT" +
+ "\020\001\022\013\n\007TIMEOUT\020\002\022\017\n\013EMPTY_INPUT\020\003\022\t\n\005ERRO" +
+ "R\020\004B:\n+com.intel.analytics.bigdl.ppml.fl" +
+ ".generatedB\013FlBaseProtob\006proto3"
+ };
+ com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner =
+ new com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner() {
+ public com.google.protobuf.ExtensionRegistry assignDescriptors(
+ com.google.protobuf.Descriptors.FileDescriptor root) {
+ descriptor = root;
+ return null;
+ }
+ };
+ com.google.protobuf.Descriptors.FileDescriptor
+ .internalBuildGeneratedFileFrom(descriptorData,
+ new com.google.protobuf.Descriptors.FileDescriptor[] {
+ }, assigner);
+ internal_static_FloatTensor_descriptor =
+ getDescriptor().getMessageTypes().get(0);
+ internal_static_FloatTensor_fieldAccessorTable = new
+ com.google.protobuf.GeneratedMessageV3.FieldAccessorTable(
+ internal_static_FloatTensor_descriptor,
+ new java.lang.String[] { "Shape", "Tensor", "Dtype", });
+ internal_static_TensorMap_descriptor =
+ getDescriptor().getMessageTypes().get(1);
+ internal_static_TensorMap_fieldAccessorTable = new
+ com.google.protobuf.GeneratedMessageV3.FieldAccessorTable(
+ internal_static_TensorMap_descriptor,
+ new java.lang.String[] { "MetaData", "TensorMap", "EncryptedTensorMap", });
+ internal_static_TensorMap_TensorMapEntry_descriptor =
+ internal_static_TensorMap_descriptor.getNestedTypes().get(0);
+ internal_static_TensorMap_TensorMapEntry_fieldAccessorTable = new
+ com.google.protobuf.GeneratedMessageV3.FieldAccessorTable(
+ internal_static_TensorMap_TensorMapEntry_descriptor,
+ new java.lang.String[] { "Key", "Value", });
+ internal_static_TensorMap_EncryptedTensorMapEntry_descriptor =
+ internal_static_TensorMap_descriptor.getNestedTypes().get(1);
+ internal_static_TensorMap_EncryptedTensorMapEntry_fieldAccessorTable = new
+ com.google.protobuf.GeneratedMessageV3.FieldAccessorTable(
+ internal_static_TensorMap_EncryptedTensorMapEntry_descriptor,
+ new java.lang.String[] { "Key", "Value", });
+ internal_static_MetaData_descriptor =
+ getDescriptor().getMessageTypes().get(2);
+ internal_static_MetaData_fieldAccessorTable = new
+ com.google.protobuf.GeneratedMessageV3.FieldAccessorTable(
+ internal_static_MetaData_descriptor,
+ new java.lang.String[] { "Name", "Version", });
+ internal_static_EncryptedTensor_descriptor =
+ getDescriptor().getMessageTypes().get(3);
+ internal_static_EncryptedTensor_fieldAccessorTable = new
+ com.google.protobuf.GeneratedMessageV3.FieldAccessorTable(
+ internal_static_EncryptedTensor_descriptor,
+ new java.lang.String[] { "Shape", "Tensor", "Dtype", });
}
// @@protoc_insertion_point(outer_class_scope)
diff --git a/scala/ppml/src/main/java/com/intel/analytics/bigdl/ppml/fl/vfl/NNStub.java b/scala/ppml/src/main/java/com/intel/analytics/bigdl/ppml/fl/vfl/NNStub.java
index da34a8ab431..defa50af8f3 100644
--- a/scala/ppml/src/main/java/com/intel/analytics/bigdl/ppml/fl/vfl/NNStub.java
+++ b/scala/ppml/src/main/java/com/intel/analytics/bigdl/ppml/fl/vfl/NNStub.java
@@ -16,6 +16,8 @@
package com.intel.analytics.bigdl.ppml.fl.vfl;
+import com.google.protobuf.ByteString;
+import com.intel.analytics.bigdl.ckks.CKKS;
import com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto.*;
import com.intel.analytics.bigdl.ppml.fl.generated.NNServiceProto.*;
import com.intel.analytics.bigdl.ppml.fl.generated.NNServiceGrpc;
@@ -23,48 +25,137 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
public class NNStub {
private static final Logger logger = LoggerFactory.getLogger(NNStub.class);
private static NNServiceGrpc.NNServiceBlockingStub stub;
Integer clientID;
+ protected CKKS ckks;
+ protected long encryptorPtr;
public NNStub(Channel channel, Integer clientID) {
this.clientID = clientID;
stub = NNServiceGrpc.newBlockingStub(channel);
}
- public TrainResponse train(TensorMap data, String algorithm) {
- TrainRequest trainRequest = TrainRequest
- .newBuilder()
- .setData(data)
- .setClientuuid(clientID)
- .setAlgorithm(algorithm)
- .build();
- logDebugMessage(data);
- return stub.train(trainRequest);
+ public NNStub(Channel channel, Integer clientID, byte[][] secrets) {
+ this.clientID = clientID;
+ stub = NNServiceGrpc.newBlockingStub(channel);
+ ckks = new CKKS();
+ encryptorPtr = ckks.createCkksEncryptor(secrets);
+ }
+
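+ /**
+ * Encrypts a single FloatTensor: copies its float values into a plain
+ * array, CKKS-encrypts the array, and wraps the ciphertext bytes in an
+ * EncryptedTensor that preserves the original shape.
+ */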
+ private EncryptedTensor encrypt(FloatTensor ft) {
+ float[] array = new float[ft.getTensorCount()];
+ for(int i = 0; i < ft.getTensorCount(); i++) {
+ array[i] = ft.getTensorList().get(i);
+ }
+ byte[] encryptedArray = ckks.ckksEncrypt(encryptorPtr, array);
+
+ EncryptedTensor et =
+ EncryptedTensor.newBuilder()
+ .addAllShape(ft.getShapeList())
+ .setTensor(ByteString.copyFrom(encryptedArray))
+ .build();
+ return et;
+ }
+
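+ /**
+ * Encrypts every FloatTensor in the TensorMap into the message's
+ * encryptedTensorMap field; the metadata is carried over unchanged.
+ */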
+ public TensorMap encrypt(TensorMap data) {
+ Map<String, FloatTensor> d = data.getTensorMapMap();
+
+ TensorMap.Builder encryptedDataBuilder = TensorMap.newBuilder()
+ .setMetaData(data.getMetaData());
+ for (Map.Entry<String, FloatTensor> fts : d.entrySet()) {
+ encryptedDataBuilder.putEncryptedTensorMap(fts.getKey(),
+ encrypt(fts.getValue()));
+ }
+ return encryptedDataBuilder.build();
+ }
+
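+ /**
+ * Decrypts a single EncryptedTensor back into a FloatTensor with the
+ * same shape as the original plaintext tensor.
+ */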
+ private FloatTensor decrypt(EncryptedTensor et) {
+ byte[] array = et.getTensor().toByteArray();
+ float[] decryptedArray = ckks.ckksDecrypt(encryptorPtr, array);
+
+ List<Float> floatList = new ArrayList<>(decryptedArray.length);
+ for (float v : decryptedArray) {
+ floatList.add(v);
+ }
+
+ FloatTensor ft =
+ FloatTensor.newBuilder()
+ .addAllShape(et.getShapeList())
+ .addAllTensor(floatList)
+ .build();
+ return ft;
+ }
+
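+ /**
+ * Decrypts every EncryptedTensor in the TensorMap back into the plain
+ * tensorMap field; the metadata is carried over unchanged.
+ */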
+ public TensorMap decrypt(TensorMap data) {
+ Map<String, EncryptedTensor> d = data.getEncryptedTensorMapMap();
+
+ TensorMap.Builder decryptedDataBuilder = TensorMap.newBuilder()
+ .setMetaData(data.getMetaData());
+ for (Map.Entry<String, EncryptedTensor> ets : d.entrySet()) {
+ decryptedDataBuilder.putTensorMap(ets.getKey(),
+ decrypt(ets.getValue()));
+ }
+ return decryptedDataBuilder.build();
+ }
+
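+ /**
+ * Sends one training round to the server. When CKKS is initialized, the
+ * payload is encrypted client-side and the server's response is decrypted
+ * before being returned; otherwise the plaintext path is used.
+ */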
+ public TensorMap train(TensorMap data, String algorithm) {
+ TrainRequest.Builder trainRequestBuilder = TrainRequest
+ .newBuilder()
+ .setClientuuid(clientID)
+ .setAlgorithm(algorithm);
+ if (null != ckks) {
+ TensorMap encryptedData = encrypt(data);
+ trainRequestBuilder.setData(encryptedData);
+ logDebugMessage(encryptedData);
+ return decrypt(stub.train(trainRequestBuilder.build()).getData());
+ } else {
+ trainRequestBuilder.setData(data);
+ logDebugMessage(data);
+ return stub.train(trainRequestBuilder.build()).getData();
+ }
}
public EvaluateResponse evaluate(TensorMap data, String algorithm, Boolean hasReturn) {
- EvaluateRequest evaluateRequest = EvaluateRequest
+ EvaluateRequest.Builder evaluateRequestBuilder = EvaluateRequest
.newBuilder()
- .setData(data)
.setReturn(hasReturn)
.setClientuuid(clientID)
- .setAlgorithm(algorithm)
- .build();
- logDebugMessage(data);
- return stub.evaluate(evaluateRequest);
+ .setAlgorithm(algorithm);
+ if (null != ckks) {
+// TODO: evaluate with CKKS
+// TensorMap encryptedData = encrypt(data);
+// evaluateRequestBuilder.setData(encryptedData);
+// logDebugMessage(encryptedData);
+ throw new UnsupportedOperationException("evaluate with CKKS is unsupported.");
+ } else {
+ evaluateRequestBuilder.setData(data);
+ logDebugMessage(data);
+ }
+ return stub.evaluate(evaluateRequestBuilder.build());
}
- public PredictResponse predict(TensorMap data, String algorithm) {
- PredictRequest predictRequest = PredictRequest
+ public TensorMap predict(TensorMap data, String algorithm) {
+ PredictRequest.Builder predictRequestBuilder = PredictRequest
.newBuilder()
.setData(data)
.setClientuuid(clientID)
- .setAlgorithm(algorithm)
- .build();
- logDebugMessage(data);
- return stub.predict(predictRequest);
+ .setAlgorithm(algorithm);
+ if (null != ckks) {
+ TensorMap encryptedData = encrypt(data);
+ predictRequestBuilder.setData(encryptedData);
+ logDebugMessage(encryptedData);
+ return decrypt(stub.predict(predictRequestBuilder.build()).getData());
+ } else {
+ predictRequestBuilder.setData(data);
+ logDebugMessage(data);
+ return stub.predict(predictRequestBuilder.build()).getData();
+ }
}
private void logDebugMessage(TensorMap data) {
diff --git a/scala/ppml/src/main/proto/fl_base.proto b/scala/ppml/src/main/proto/fl_base.proto
index 7d9e2c879a8..fdc032fd78b 100644
--- a/scala/ppml/src/main/proto/fl_base.proto
+++ b/scala/ppml/src/main/proto/fl_base.proto
@@ -37,9 +37,16 @@ message FloatTensor {
message TensorMap {
MetaData metaData = 1;
 map<string, FloatTensor> tensorMap = 2;
+ map<string, EncryptedTensor> encryptedTensorMap = 3;
}
//
message MetaData {
string name = 1;
int32 version = 2;
}
+
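+// A CKKS ciphertext of a FloatTensor: `tensor` carries the serialized
+// ciphertext bytes and `shape` records the plaintext tensor's dimensions.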
+message EncryptedTensor {
+ repeated int32 shape = 1;
+ bytes tensor = 2;
+ string dtype = 3;
+}
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/FLClient.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/FLClient.scala
index b6d86a5cd30..3324e7c62bb 100644
--- a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/FLClient.scala
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/FLClient.scala
@@ -68,6 +68,10 @@ class FLClient(val _args: Array[String]) extends GrpcClientBase(_args) {
fgbostStub = new FGBoostStub(channel, clientID)
}
+ def initCkks(secret: Array[Array[Byte]]): Unit = {
+ nnStub = new NNStub(channel, clientID, secret)
+ }
+
override def shutdown(): Unit = {
try channel.shutdown.awaitTermination(5, TimeUnit.SECONDS)
catch {
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/FLContext.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/FLContext.scala
index d6145db9c25..ef951a4a299 100644
--- a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/FLContext.scala
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/FLContext.scala
@@ -16,6 +16,7 @@
package com.intel.analytics.bigdl.ppml.fl
+import com.intel.analytics.bigdl.ckks.CKKS
import com.intel.analytics.bigdl.dllib.utils.Engine
import org.apache.log4j.LogManager
import org.apache.spark.SparkConf
@@ -42,6 +43,14 @@ object FLContext {
flClient.psiSalt
}
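+ // Switches the NN stub to the CKKS-enabled constructor; call this after the
+ // FL client has been created (e.g. after initFLContext).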
+ def initCkks(secrets: Array[Array[Byte]]): Unit = {
+ flClient.initCkks(secrets)
+ }
+
+ def initCkks(secretsPath: String): Unit = {
+ flClient.initCkks(CKKS.loadSecret(secretsPath))
+ }
+
def initFLContext(id: Int, target: String = null): Unit = {
createSparkSession()
Engine.init
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/FLServer.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/FLServer.scala
index ae879f2bdc6..27444c4239b 100644
--- a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/FLServer.scala
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/FLServer.scala
@@ -51,6 +51,8 @@ class FLServer private[ppml](val _args: Array[String] = null) extends GrpcServer
configPath = "ppml-conf.yaml"
var clientNum: Int = 1
val fgBoostConfig = new FLConfig()
+ var nnService: NNServiceImpl = null
+ var ckksSecretPath = ""
parseConfig()
def setClientNum(clientNum: Int): Unit = {
@@ -66,6 +68,7 @@ class FLServer private[ppml](val _args: Array[String] = null) extends GrpcServer
certChainFilePath = flHelper.certChainFilePath
privateKeyFilePath = flHelper.privateKeyFilePath
fgBoostConfig.setModelPath(flHelper.fgBoostServerModelPath)
+ ckksSecretPath = flHelper.ckksSercetPath
}
}
@@ -80,7 +83,20 @@ class FLServer private[ppml](val _args: Array[String] = null) extends GrpcServer
}
def addService(): Unit = {
serverServices.add(new PSIServiceImpl(clientNum))
- serverServices.add(new NNServiceImpl(clientNum))
+ if (nnService == null) {
+ nnService = new NNServiceImpl(clientNum)
+ }
+ if (ckksSecretPath.nonEmpty) {
+ nnService.initCkksAggregator(ckksSecretPath)
+ }
+ serverServices.add(nnService)
serverServices.add(new FGBoostServiceImpl(clientNum, fgBoostConfig))
}
+
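+ // Installs the CKKS aggregator directly from in-memory secrets (used by the
+ // ckks example's StartServer) instead of reading them from ckksSercetPath.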
+ private[bigdl] def setCkksAggregator(secret: Array[Array[Byte]]): Unit = {
+ if (nnService == null) {
+ nnService = new NNServiceImpl(clientNum)
+ }
+ nnService.initCkksAggregator(secret)
+ }
}
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/NNModel.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/NNModel.scala
index 9d3be827b52..fb2be7f97f5 100644
--- a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/NNModel.scala
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/NNModel.scala
@@ -46,6 +46,7 @@ abstract class NNModel() {
VFLTensorUtils.featureLabelToMiniBatch(xTrain, yTrain, batchSize),
VFLTensorUtils.featureLabelToMiniBatch(xValidate, yValidate, batchSize))
}
+
/**
*
* @param trainData DataFrame of training data
@@ -117,6 +118,7 @@ abstract class NNModel() {
def predict(x: Tensor[Float], batchSize: Int = 4): Array[Activity] = {
estimator.predict(VFLTensorUtils.featureLabelToMiniBatch(x, null, batchSize))
}
+
/**
*
* @param data DataFrame of prediction data
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/algorithms/VFLLogisticRegression.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/algorithms/VFLLogisticRegression.scala
index eb1b634510a..1ce13fc4082 100644
--- a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/algorithms/VFLLogisticRegression.scala
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/algorithms/VFLLogisticRegression.scala
@@ -16,8 +16,11 @@
package com.intel.analytics.bigdl.ppml.fl.algorithms
+import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.dllib.nn.{Linear, Sequential}
import com.intel.analytics.bigdl.dllib.optim.Adam
+import com.intel.analytics.bigdl.dllib.tensor.Tensor
+import com.intel.analytics.bigdl.dllib.utils.Log4Error
import com.intel.analytics.bigdl.ppml.fl.NNModel
import com.intel.analytics.bigdl.ppml.fl.nn.VFLNNEstimator
import com.intel.analytics.bigdl.ppml.fl.utils.FLClientClosable
@@ -27,10 +30,17 @@ import com.intel.analytics.bigdl.ppml.fl.utils.FLClientClosable
* @param featureNum
* @param learningRate
*/
-class VFLLogisticRegression(featureNum: Int,
- learningRate: Float = 0.005f) extends NNModel() {
- val model = Sequential[Float]().add(Linear(featureNum, 1))
+class VFLLogisticRegression(featureNum: Int = -1,
+ learningRate: Float = 0.005f,
+ customModel: Module[Float] = null,
+ algorithm: String = "vfl_logistic_regression") extends NNModel() {
+ Log4Error.invalidInputError(featureNum != -1 || customModel != null,
+ "Either featureNum or customModel should be provided")
+ val clientModule = if (customModel == null) {
+ Linear[Float](featureNum, 1)
+ } else customModel
+ val model = Sequential[Float]().add(clientModule)
override val estimator = new VFLNNEstimator(
- "vfl_logistic_regression", model, new Adam(learningRate))
+ algorithm, model, new Adam(learningRate))
}
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/base/Estimator.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/base/Estimator.scala
index 67b6f0240b9..961b40ce721 100644
--- a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/base/Estimator.scala
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/base/Estimator.scala
@@ -18,6 +18,7 @@ package com.intel.analytics.bigdl.ppml.fl.base
import com.intel.analytics.bigdl.dllib.feature.dataset.{LocalDataSet, MiniBatch}
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity
+import com.intel.analytics.bigdl.dllib.tensor.Tensor
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
@@ -27,6 +28,7 @@ trait Estimator {
def getEvaluateResults(): Map[String, Array[Float]] = {
evaluateResults.map(v => (v._1, v._2.toArray)).toMap
}
+
def train(endEpoch: Int,
trainDataSet: LocalDataSet[MiniBatch[Float]],
valDataSet: LocalDataSet[MiniBatch[Float]]): Any
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/example/ckks/DataPreprocessing.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/example/ckks/DataPreprocessing.scala
new file mode 100644
index 00000000000..bd10f37062c
--- /dev/null
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/example/ckks/DataPreprocessing.scala
@@ -0,0 +1,342 @@
+/*
+ * Copyright 2016 The BigDL Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.intel.analytics.bigdl.ppml.fl.example.ckks
+
+import com.intel.analytics.bigdl.DataSet
+import com.intel.analytics.bigdl.dllib.feature.dataset.{DataSet, MiniBatch, Sample, SampleToMiniBatch, TensorSample}
+import com.intel.analytics.bigdl.dllib.tensor.Tensor
+import com.intel.analytics.bigdl.dllib.utils.T
+import org.apache.logging.log4j.Level
+import org.apache.logging.log4j.core.config.Configurator
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.apache.spark.sql.functions.{array, col, udf}
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
+class DataPreprocessing(spark: SparkSession,
+ trainDataPath: String,
+ testDataPath: String,
+ clientId: Int) extends Serializable {
+ case class Record(
+ age: Int,
+ workclass: String,
+ fnlwgt: Int,
+ education: String,
+ education_num: Int,
+ marital_status: String,
+ occupation: String,
+ relationship: String,
+ race: String,
+ gender: String,
+ capital_gain: Int,
+ capital_loss: Int,
+ hours_per_week: Int,
+ native_country: String,
+ income_bracket: String
+ )
+ val batchSize = 8192
+ val modelType = "wide"
+
+ val recordSchema = StructType(Array(
+ StructField("age", IntegerType, false),
+ StructField("workclass", StringType, false),
+ StructField("fnlwgt", IntegerType, false),
+ StructField("education", StringType, false),
+ StructField("education_num", IntegerType, false),
+ StructField("marital_status", StringType, false),
+ StructField("occupation", StringType, false),
+ StructField("relationship", StringType, false),
+ StructField("race", StringType, false),
+ StructField("gender", StringType, false),
+ StructField("capital_gain", IntegerType, false),
+ StructField("capital_loss", IntegerType, false),
+ StructField("hours_per_week", IntegerType, false),
+ StructField("native_country", StringType, false),
+ StructField("income_bracket", StringType, false)
+ ))
+
+
+ case class RecordSample[T: ClassTag](sample: Sample[T])
+
+ Configurator.setLevel("org", Level.ERROR)
+
+ def loadCensusData():
+ (DataSet[MiniBatch[Float]], DataSet[MiniBatch[Float]]) = {
+ val training = spark.sparkContext
+ .textFile(trainDataPath)
+ .map(_.split(",").map(_.trim))
+ .filter(_.size == 15).map(array =>
+ Row(
+ array(0).toInt, array(1), array(2).toInt, array(3), array(4).toInt,
+ array(5), array(6), array(7), array(8), array(9),
+ array(10).toInt, array(11).toInt, array(12).toInt, array(13), array(14)
+ )
+ )
+
+ val validation = spark.sparkContext
+ .textFile(testDataPath)
+ .map(_.dropRight(1)) // remove dot at the end of each line in adult.test
+ .map(_.split(",").map(_.trim))
+ .filter(_.size == 15).map(array =>
+ Row(
+ array(0).toInt, array(1), array(2).toInt, array(3), array(4).toInt,
+ array(5), array(6), array(7), array(8), array(9),
+ array(10).toInt, array(11).toInt, array(12).toInt, array(13), array(14)
+ ))
+
+ val (trainDf, valDf) = (spark.createDataFrame(training, recordSchema),
+ spark.createDataFrame(validation, recordSchema))
+
+ trainDf.show(10)
+ val localColumnInfo = if (clientId == 1) {
+ ColumnFeatureInfo(
+ wideBaseCols = Array("edu", "occ", "age_bucket"),
+ wideBaseDims = Array(16, 1000, 11),
+ wideCrossCols = Array("edu_occ", "age_edu_occ"),
+ wideCrossDims = Array(1000, 1000),
+ indicatorCols = Array("work", "edu", "mari"),
+ indicatorDims = Array(9, 16, 7),
+ embedCols = Array("occ"),
+ embedInDims = Array(1000),
+ embedOutDims = Array(8),
+ continuousCols = Array("age", "education_num"))
+ } else {
+ ColumnFeatureInfo(
+ wideBaseCols = Array("rela", "work", "mari"),
+ wideBaseDims = Array(6, 9, 7),
+ indicatorCols = Array("rela"),
+ indicatorDims = Array(6),
+ // TODO: the error may well be the missing field here
+ continuousCols = Array("capital_gain",
+ "capital_loss", "hours_per_week"))
+ }
+
+
+ val isImplicit = false
+ val trainpairFeatureRdds =
+ assemblyFeature(isImplicit, trainDf, localColumnInfo, modelType)
+
+ val validationpairFeatureRdds =
+ assemblyFeature(isImplicit, valDf, localColumnInfo, modelType)
+
+ val trainDataset = DataSet.array(
+ trainpairFeatureRdds.map(_.sample).collect()) -> SampleToMiniBatch[Float](batchSize)
+ val validationDataset = DataSet.array(
+ validationpairFeatureRdds.map(_.sample).collect()) -> SampleToMiniBatch[Float](batchSize)
+ (trainDataset, validationDataset)
+ }
+
+ case class ColumnFeatureInfo(wideBaseCols: Array[String] = Array[String](),
+ wideBaseDims: Array[Int] = Array[Int](),
+ wideCrossCols: Array[String] = Array[String](),
+ wideCrossDims: Array[Int] = Array[Int](),
+ indicatorCols: Array[String] = Array[String](),
+ indicatorDims: Array[Int] = Array[Int](),
+ embedCols: Array[String] = Array[String](),
+ embedInDims: Array[Int] = Array[Int](),
+ embedOutDims: Array[Int] = Array[Int](),
+ continuousCols: Array[String] = Array[String](),
+ label: String = "label") extends Serializable {
+ override def toString: String = {
+ "wideBaseCols:" + wideBaseCols.mkString(",") + "\n" +
+ "wideBaseDims:" + wideBaseDims.mkString(",") + "\n" +
+ "wideCrossCols:" + wideCrossCols.mkString(",") + "\n" +
+ "wideCrossDims:" + wideCrossDims.mkString(",") + "\n" +
+ "indicatorCols:" + indicatorCols.mkString(",") + "\n" +
+ "indicatorDims:" + indicatorDims.mkString(",") + "\n" +
+ "embedCols:" + embedCols.mkString(",") + "\n" +
+ "embedInDims:" + embedInDims.mkString(",") + "\n" +
+ "embedOutDims:" + embedOutDims.mkString(",") + "\n" +
+ "continuousCols:" + continuousCols.mkString(",") + "\n" +
+ "label:" + label
+
+ }
+ }
+
+ def categoricalFromVocabList(vocabList: Array[String]): (String) => Int = {
+ val func = (sth: String) => {
+ val default: Int = 0
+ val start: Int = 1
+ if (vocabList.contains(sth)) vocabList.indexOf(sth) + start
+ else default
+ }
+ func
+ }
+
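+ // Hash-buckets a tuple of categorical columns: join the values with "_",
+ // hash, and map into [0, bucketSize) -- the usual crossed-column trick.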
+ def buckBuckets(bucketSize: Int)(col: String*): Int = {
+ Math.abs(col.reduce(_ + "_" + _).hashCode()) % bucketSize
+ }
+
+ def bucketizedColumn(boundaries: Array[Float]): Float => Int = {
+ col1: Float => {
+ var index = 0
+ while (index < boundaries.length && col1 >= boundaries(index)) {
+ index += 1
+ }
+ index
+ }
+ }
+
+ def getDeepTensor(r: Row, columnInfo: ColumnFeatureInfo): Tensor[Float] = {
+ val deepColumns1 = columnInfo.indicatorCols
+ val deepColumns2 = columnInfo.embedCols ++ columnInfo.continuousCols
+ val deepLength = columnInfo.indicatorDims.sum + deepColumns2.length
+ val deepTensor = Tensor[Float](deepLength).fill(0)
+
+ // setup indicators
+ var acc = 0
+ (0 to deepColumns1.length - 1).map {
+ i =>
+ val index = r.getAs[Int](columnInfo.indicatorCols(i))
+ val accIndex = if (i == 0) index
+ else {
+ acc = acc + columnInfo.indicatorDims(i - 1)
+ acc + index
+ }
+ deepTensor.setValue(accIndex + 1, 1)
+ }
+
+ // setup embedding and continuous
+ (0 to deepColumns2.length - 1).map {
+ i =>
+ deepTensor.setValue(i + 1 + columnInfo.indicatorDims.sum,
+ r.getAs[Int](deepColumns2(i)).toFloat)
+ }
+ deepTensor
+ }
+
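+ // Builds a sparse one-hot tensor over the concatenated wide columns: each
+ // column's bucket index is offset by the cumulative dims of the columns
+ // before it, so all wide features share a single index space.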
+ def getWideTensor(r: Row, columnInfo: ColumnFeatureInfo): Tensor[Float] = {
+ val wideColumns = columnInfo.wideBaseCols ++ columnInfo.wideCrossCols
+ val wideDims = columnInfo.wideBaseDims ++ columnInfo.wideCrossDims
+ val wideLength = wideColumns.length
+ var acc = 0
+ val indices: Array[Int] = (0 to wideLength - 1).map(i => {
+ val index = r.getAs[Int](wideColumns(i))
+ if (i == 0) {
+ index
+ }
+ else {
+ acc = acc + wideDims(i - 1)
+ acc + index
+ }
+ }).toArray
+ val values = indices.map(_ => 1.0f)
+ val shape = Array(wideDims.sum)
+
+ Tensor.sparse(Array(indices), values, shape)
+ }
+
+ def getWideTensorSequential(r: Row, columnInfo: ColumnFeatureInfo): Tensor[Float] = {
+ val wideColumns = columnInfo.wideBaseCols ++ columnInfo.wideCrossCols
+ val wideDims = columnInfo.wideBaseDims ++ columnInfo.wideCrossDims
+ val wideLength = wideColumns.length
+ var acc = 0
+ val indices: Array[Int] = (0 to wideLength - 1).map(i => {
+ val index = r.getAs[Int](wideColumns(i))
+ if (i == 0) index
+ else {
+ acc = acc + wideDims(i - 1)
+ acc + index
+ }
+ }).toArray
+ val values = indices.map(_ => 1.0f)
+ val shape = Array(wideDims.sum)
+
+ Tensor.sparse(Array(indices), values, shape)
+ }
+
+ def row2SampleSequential(r: Row,
+ columnInfo: ColumnFeatureInfo,
+ modelType: String): Sample[Float] = {
+ val wideTensor: Tensor[Float] = getWideTensorSequential(r, columnInfo)
+ val deepTensor: Tensor[Float] = getDeepTensor(r, columnInfo)
+
+ val label = if (clientId == 2) {
+ val l = r.getAs[Int](columnInfo.label)
+ val label = Tensor[Float](T(l))
+ Array(label.resize(1))
+ } else {
+ Array[Tensor[Float]]()
+ }
+
+ modelType match {
+ case "wide_n_deep" =>
+ TensorSample[Float](Array(wideTensor, deepTensor), label)
+ case "wide" =>
+ TensorSample[Float](Array(wideTensor), label)
+ case "deep" =>
+ TensorSample[Float](Array(deepTensor), label)
+ case _ =>
+ throw new IllegalArgumentException("unknown type")
+ }
+ }
+
+ // convert features to RDD[Sample[Float]]
+ def assemblyFeature(isImplicit: Boolean = false,
+ dataDf: DataFrame,
+ columnInfo: ColumnFeatureInfo,
+ modelType: String): RDD[RecordSample[Float]] = {
+ val educationVocab = Array("Bachelors", "HS-grad", "11th", "Masters", "9th",
+ "Some-college", "Assoc-acdm", "Assoc-voc", "7th-8th",
+ "Doctorate", "Prof-school", "5th-6th", "10th", "1st-4th",
+ "Preschool", "12th") // 16
+ val maritalStatusVocab = Array("Married-civ-spouse", "Divorced", "Married-spouse-absent",
+ "Never-married", "Separated", "Married-AF-spouse", "Widowed")
+ val relationshipVocab = Array("Husband", "Not-in-family", "Wife", "Own-child", "Unmarried",
+ "Other-relative") // 6
+ val workclassVocab = Array("Self-emp-not-inc", "Private", "State-gov", "Federal-gov",
+ "Local-gov", "?", "Self-emp-inc", "Without-pay", "Never-worked") // 9
+ val genderVocab = Array("Female", "Male")
+
+ val ages = Array(18f, 25, 30, 35, 40, 45, 50, 55, 60, 65)
+
+ val educationVocabUdf = udf(categoricalFromVocabList(educationVocab))
+ val maritalStatusVocabUdf = udf(categoricalFromVocabList(maritalStatusVocab))
+ val relationshipVocabUdf = udf(categoricalFromVocabList(relationshipVocab))
+ val workclassVocabUdf = udf(categoricalFromVocabList(workclassVocab))
+ val genderVocabUdf = udf(categoricalFromVocabList(genderVocab))
+
+ val bucket1Udf = udf(buckBuckets(1000)(_: String))
+ val bucket2Udf = udf(buckBuckets(1000)(_: String, _: String))
+ val bucket3Udf = udf(buckBuckets(1000)(_: String, _: String, _: String))
+
+ val ageBucketUdf = udf(bucketizedColumn(ages))
+
+ val incomeUdf = udf((income: String) => if (income == ">50K" || income == ">50K.") 1 else 0)
+
+ val data = dataDf
+ .withColumn("age_bucket", ageBucketUdf(col("age")))
+ .withColumn("edu_occ", bucket2Udf(col("education"), col("occupation")))
+ .withColumn("age_edu_occ", bucket3Udf(col("age_bucket"), col("education"),
+ col("occupation")))
+ .withColumn("edu", educationVocabUdf(col("education")))
+ .withColumn("mari", maritalStatusVocabUdf(col("marital_status")))
+ .withColumn("rela", relationshipVocabUdf(col("relationship")))
+ .withColumn("work", workclassVocabUdf(col("workclass")))
+ .withColumn("occ", bucket1Udf(col("occupation")))
+ .withColumn("label", incomeUdf(col("income_bracket")))
+ val rddOfSample = data.rdd.map(r => {
+ RecordSample(row2SampleSequential(r, columnInfo, modelType))
+ })
+ rddOfSample
+ }
+
+}
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/example/ckks/GenerateCkksSecret.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/example/ckks/GenerateCkksSecret.scala
new file mode 100644
index 00000000000..7342c076720
--- /dev/null
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/example/ckks/GenerateCkksSecret.scala
@@ -0,0 +1,32 @@
+/*
+ * Copyright 2016 The BigDL Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.intel.analytics.bigdl.ppml.fl.example.ckks
+
+import com.intel.analytics.bigdl.ckks.CKKS
+
+// TODO: key should be provided by Key Management System
+object GenerateCkksSecret {
+ def main(args: Array[String]): Unit = {
+ if (args.length >= 1) {
+ val ckks = new CKKS()
+ val keys = ckks.createSecrets()
+ CKKS.saveSecret(keys, args(0))
+ println("save secret to " + args(0))
+ } else {
+ println("please provide a path to save secret.")
+ }
+ }
+}
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/example/ckks/README.md b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/example/ckks/README.md
new file mode 100644
index 00000000000..6d36355568a
--- /dev/null
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/example/ckks/README.md
@@ -0,0 +1,42 @@
+### Data
+We use the [Census]() (Adult) dataset in this example.
+
+To simulate a two-party scenario, we assign different features of the Census data to each client.
+
+The original data has 15 columns. During preprocessing, new features are created by crossing some of the existing columns.
+
+* data of client 1: `age`, `education`, `occupation`, cross columns: `edu_occ`, `age_edu_occ`
+* data of client 2: `relationship`, `workclass`, `marital_status`
+
+### Download BigDL assembly
+
+Download BigDL assembly from [BigDL-Release](https://bigdl.readthedocs.io/en/latest/doc/release.html)
+
+### Generate secret
+
+```bash
+java -cp bigdl-ppml-[version]-jar-with-all-dependencies.jar com.intel.analytics.bigdl.ppml.fl.example.ckks.GenerateCkksSecret ckks.crt
+```
+
+### Start FLServer
+Before starting the server, modify the config file `ppml-conf.yaml`: this application has 2 clients globally, and the server needs the absolute path to the CKKS secret, so use the following config:
+```
+worldSize: 2
+ckksSercetPath: /[absolute path]/ckks.crt
+```
+Then start FLServer at server machine
+```bash
+java -cp bigdl-ppml-[version]-jar-with-all-dependencies.jar com.intel.analytics.bigdl.ppml.fl.FLServer
+```
+
+### Start Local Trainers
+Start the local Logistic Regression trainers at 2 training machines
+```bash
+java -cp bigdl-ppml-[version]-jar-with-all-dependencies.jar \
+  com.intel.analytics.bigdl.ppml.fl.example.ckks.VflLogisticRegressionCkks \
+  -d [path to adult dataset] \
+  -i 1 \
+  -s [path to ckks.crt]
+# change -i 1 to -i 2 at client-2
+```
+
+The example trains the model and then evaluates the training result.
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/example/ckks/StartServer.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/example/ckks/StartServer.scala
new file mode 100644
index 00000000000..41c1ce46e1c
--- /dev/null
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/example/ckks/StartServer.scala
@@ -0,0 +1,35 @@
+/*
+ * Copyright 2016 The BigDL Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.intel.analytics.bigdl.ppml.fl.example.ckks
+
+import com.intel.analytics.bigdl.ckks.CKKS
+import com.intel.analytics.bigdl.ppml.fl.FLServer
+
+object StartServer {
+ def main(args: Array[String]): Unit = {
+ val flServer = new FLServer()
+
+ flServer.setClientNum(2)
+ if (args.length > 0) {
+ val secretsPath = args(0)
+ flServer.setCkksAggregator(
+ CKKS.loadSecret(secretsPath))
+ }
+ flServer.build()
+ flServer.start()
+ flServer.blockUntilShutdown()
+ }
+}
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/example/ckks/VflLogisticRegressionCkks.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/example/ckks/VflLogisticRegressionCkks.scala
new file mode 100644
index 00000000000..1d06dc33b4a
--- /dev/null
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/example/ckks/VflLogisticRegressionCkks.scala
@@ -0,0 +1,115 @@
+/*
+ * Copyright 2016 The BigDL Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.intel.analytics.bigdl.ppml.fl.example.ckks
+
+import com.intel.analytics.bigdl.ckks.CKKS
+import com.intel.analytics.bigdl.dllib.NNContext
+import com.intel.analytics.bigdl.dllib.feature.dataset.{DataSet, Sample, SampleToMiniBatch, TensorSample}
+import com.intel.analytics.bigdl.dllib.keras.metrics.BinaryAccuracy
+import com.intel.analytics.bigdl.dllib.nn.{BCECriterion, Sigmoid, SparseLinear}
+import com.intel.analytics.bigdl.dllib.optim.{Adagrad, Ftrl, SGD, Top1Accuracy}
+import com.intel.analytics.bigdl.dllib.tensor.{Storage, Tensor}
+import com.intel.analytics.bigdl.dllib.utils.{Engine, RandomGenerator, T}
+import com.intel.analytics.bigdl.ppml.fl.algorithms.VFLLogisticRegression
+import com.intel.analytics.bigdl.ppml.fl.{FLContext, FLServer, NNModel}
+import com.intel.analytics.bigdl.ppml.fl.example.ckks.DataPreprocessing
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.{SparkConf, SparkContext}
+import scopt.OptionParser
+
+import java.util
+
+
+object VflLogisticRegressionCkks {
+ case class CmdArgs(dataPath: String = null,
+ clientId: Int = 1,
+ mode: String = "ckks",
+ secretPath: String = ""
+ )
+ val parser = new OptionParser[CmdArgs]("PPML CKKS example") {
+ opt[String]('d', "dataPath")
+ .text("data path")
+ .action((x, c) => c.copy(dataPath = x))
+ .required()
+ opt[Int]('i', "id")
+ .text("client id")
+ .action((x, c) => c.copy(clientId = x))
+ .required()
+ opt[String]('m', "mode")
+ .text("ckks or dllib")
+ .action((x, c) => c.copy(mode = x))
+ opt[String]('s', "secret")
+ .text("ckks secret path, not none when mode is ckks")
+ .action((x, c) => c.copy(secretePath = x))
+ }
+
+
+ def main(args: Array[String]): Unit = {
+ parser.parse(args, CmdArgs()).foreach { param =>
+
+ val inputDir = param.dataPath
+ val clientId = param.clientId
+
+ val trainDataPath = s"$inputDir/adult-${clientId}.data"
+ val testDataPath = s"$inputDir/adult-${clientId}.test"
+ val mode = param.mode
+ val ckksSecretPath = param.secretPath
+
+ FLContext.initFLContext(clientId)
+ val sqlContext = SparkSession.builder().getOrCreate()
+ val pre = new DataPreprocessing(sqlContext, trainDataPath, testDataPath, clientId)
+ val (trainDataset, validationDataset) = pre.loadCensusData()
+
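+ // Wide feature dims: client 2 owns relationship(6) + workclass(9) +
+ // marital_status(7); client 1 owns the remainder of the (assumed) 3049
+ // total wide dims, i.e. 16 + 1000 + 11 base plus 1000 + 1000 cross.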
+ val numFeature = if (clientId == 1) {
+ 3049 - 6 - 9 - 7
+ } else {
+ 6 + 9 + 7
+ }
+
+ val linear = if (clientId == 1) {
+ SparseLinear[Float](numFeature, 1, withBias = false)
+ } else {
+ SparseLinear[Float](numFeature, 1, withBias = true)
+ }
+ linear.getParameters()._1.randn(0, 0.001)
+
+ val lr: NNModel = mode match {
+ case "dllib" => new VFLLogisticRegression(numFeature, 0.005f, linear)
+ case "ckks" =>
+ FLContext.initCkks(ckksSecretPath)
+ new VFLLogisticRegression(numFeature, 0.005f, linear, "vfl_logistic_regression_ckks")
+ case _ => throw new IllegalArgumentException(s"unknown mode: $mode")
+ }
+ lr.estimator.train(40, trainDataset.toLocal(), null)
+
+ val validationMethod = new BinaryAccuracy[Float]()
+ val predictions = lr.estimator.predict(validationDataset.toLocal())
+
+ if (clientId == 2) {
+ // client2 has target
+ val evalData = validationDataset.toLocal().data(false)
+ val result = predictions.toIterator.zip(evalData).map { case (prediction, batch) =>
+ validationMethod.apply(prediction, batch.getTarget())
+ }.reduce(_ + _)
+
+ println(result.toString())
+ }
+
+ }
+ }
+
+}
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/HFLNNEstimator.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/HFLNNEstimator.scala
index ba1af3c1afe..c8999fa3b6a 100644
--- a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/HFLNNEstimator.scala
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/HFLNNEstimator.scala
@@ -64,7 +64,7 @@ class HFLNNEstimator(algorithm: String,
localEstimator.fit(trainSet.toSeq, size.toInt, valSet.toSeq)
logger.debug(s"Local train step ends, syncing version: $iteration with server.")
val weights = getModelWeightTable(model, iteration)
- val serverWeights = flClient.nnStub.train(weights, algorithm).getData
+ val serverWeights = flClient.nnStub.train(weights, algorithm)
// model replace
updateModel(model, serverWeights)
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/NNServiceImpl.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/NNServiceImpl.scala
index 27b14e84161..896105ca6c4 100644
--- a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/NNServiceImpl.scala
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/NNServiceImpl.scala
@@ -17,8 +17,7 @@
package com.intel.analytics.bigdl.ppml.fl.nn
-import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity
-
+import com.intel.analytics.bigdl.ckks.CKKS
import java.util
import java.util.Map
import com.intel.analytics.bigdl.dllib.nn.{BCECriterion, MSECriterion, Sigmoid, View}
@@ -52,16 +51,28 @@ class NNServiceImpl(clientNum: Int) extends NNServiceGrpc.NNServiceImplBase {
})
}
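+ // Registers a CKKS-based aggregator under the "vfl_logistic_regression_ckks"
+ // algorithm key, so encrypted train/predict requests are routed to it.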
+ def initCkksAggregator(secretPath: String): Unit = {
+ val secret = CKKS.loadSecret(secretPath)
+ initCkksAggregator(secret)
+ }
+ def initCkksAggregator(secret: Array[Array[Byte]]): Unit = {
+ val ckks = new CKKS()
+ val ckksCommonInstance = ckks.createCkksCommonInstance(secret)
+ val ckksAggregator = new VFLNNAggregatorCkks(ckksCommonInstance)
+ ckksAggregator.setClientNum(clientNum)
+ aggregatorMap.put("vfl_logistic_regression_ckks", ckksAggregator)
+ }
+
override def train(request: TrainRequest,
responseObserver: StreamObserver[TrainResponse]): Unit = {
val clientUUID = request.getClientuuid
- logger.debug("Server get train request from client: " + clientUUID)
+ logger.info("Server get train request from client: " + clientUUID)
val data = request.getData
val version = data.getMetaData.getVersion
val aggregator = aggregatorMap.get(request.getAlgorithm)
try {
aggregator.putClientData(TRAIN, clientUUID, version, new DataHolder(data))
- logger.debug(s"$clientUUID getting server new data to update local")
+ logger.info(s"$clientUUID getting server new data to update local")
val responseData = aggregator.getStorage(TRAIN).serverData
if (responseData == null) {
val response = "Data requested doesn't exist"
@@ -76,6 +87,7 @@ class NNServiceImpl(clientNum: Int) extends NNServiceGrpc.NNServiceImplBase {
} catch {
case e: Exception =>
val errorMsg = ExceptionUtils.getStackTrace(e)
+ logger.error(errorMsg)
val response = TrainResponse.newBuilder.setResponse(errorMsg).setCode(1).build
responseObserver.onNext(response)
responseObserver.onCompleted()
@@ -117,6 +129,7 @@ class NNServiceImpl(clientNum: Int) extends NNServiceGrpc.NNServiceImplBase {
} catch {
case e: Exception =>
val errorMsg = ExceptionUtils.getStackTrace(e)
+ logger.error(errorMsg)
val response = EvaluateResponse.newBuilder.setResponse(errorMsg).setCode(1).build
responseObserver.onNext(response)
responseObserver.onCompleted()
@@ -146,6 +159,7 @@ class NNServiceImpl(clientNum: Int) extends NNServiceGrpc.NNServiceImplBase {
} catch {
case e: Exception =>
val errorMsg = ExceptionUtils.getStackTrace(e)
+ logger.error(errorMsg)
val response = PredictResponse.newBuilder.setResponse(errorMsg).setCode(1).build
responseObserver.onNext(response)
responseObserver.onCompleted()
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/VFLNNAggregator.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/VFLNNAggregator.scala
index 3f9af59398a..767b73eff76 100644
--- a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/VFLNNAggregator.scala
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/VFLNNAggregator.scala
@@ -33,7 +33,6 @@ import org.apache.logging.log4j.LogManager
/**
*
* @param model
- * @param optimMethod
* @param criterion loss function, HFL takes loss at estimator, VFL takes loss at aggregator
* @param validationMethods
*/
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/VFLNNAggregatorCkks.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/VFLNNAggregatorCkks.scala
new file mode 100644
index 00000000000..3ba3de32378
--- /dev/null
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/VFLNNAggregatorCkks.scala
@@ -0,0 +1,81 @@
+/*
+ * Copyright 2016 The BigDL Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.intel.analytics.bigdl.ppml.fl.nn
+
+
+import com.intel.analytics.bigdl.ppml.fl.nn.ckks.{CAddTable, FusedBCECriterion}
+import com.intel.analytics.bigdl.dllib.optim.{OptimMethod, ValidationMethod, ValidationResult}
+import com.intel.analytics.bigdl.dllib.utils.{Log4Error, T}
+import com.intel.analytics.bigdl.ppml.fl.common.FLPhase
+import com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto._
+import com.intel.analytics.bigdl.ppml.fl.utils.ProtoUtils
+
+
+/**
+ * Server-side aggregator for VFL logistic regression over CKKS ciphertexts:
+ * it sums the encrypted client outputs and computes the loss and gradients
+ * without ever decrypting the data.
+ *
+ * @param ckksCommon native handle of the shared CKKS context
+ * @param optimMethod optional server-side optim method (currently unused)
+ * @param validationMethods optional validation methods (currently unused)
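+ *
+ * A minimal sketch of the wiring used elsewhere in this patch (see
+ * NNServiceImpl.initCkksAggregator):
+ * {{{
+ *   val ckks = new CKKS()
+ *   val ckksCommon = ckks.createCkksCommonInstance(secret)
+ *   val aggregator = new VFLNNAggregatorCkks(ckksCommon)
+ *   aggregator.setClientNum(clientNum)
+ * }}}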
+ */
+class VFLNNAggregatorCkks(
+ ckksCommon: Long,
+ optimMethod: OptimMethod[Float] = null,
+ validationMethods: Array[ValidationMethod[Float]] = null
+ ) extends NNAggregator {
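+ // Ciphertext-domain building blocks: CAddTable sums the encrypted client
+ // outputs and FusedBCECriterion computes the BCE loss/gradient without
+ // decrypting.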
+ val addTable = CAddTable(ckksCommon)
+ val criterion = FusedBCECriterion(ckksCommon)
+
+ var validationResult = List[Array[ValidationResult]]()
+
+ /**
+ * Aggregate the clients' encrypted data and update the server data for the given phase.
+ * @param flPhase FLPhase enum type, one of TRAIN, EVAL, PREDICT
+ */
+ override def aggregate(flPhase: FLPhase): Unit = {
+ val storage = getStorage(flPhase)
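+ // Unpack each client's encrypted "output", the encrypted "target", and the
+ // plaintext shapes recorded alongside the ciphertexts.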
+ val (inputTable, target, shapeGrad, shapeLoss) = ProtoUtils.ckksProtoToBytes(storage)
+
+ val output = addTable.updateOutput(inputTable: _*)
+
+ val metaBuilder = MetaData.newBuilder()
+ var aggregatedTable: TensorMap = null
+ flPhase match {
+ case FLPhase.TRAIN =>
+ val loss = criterion.forward(output, target)
+ val grad = criterion.backward(output, target)
+ val meta = metaBuilder.setName("gradInput").setVersion(storage.version).build()
+ // Pass the encrypted bytes back to the clients
+ aggregatedTable = TensorMap.newBuilder()
+ .setMetaData(meta)
+ .putEncryptedTensorMap("gradInput", ProtoUtils.bytesToCkksProto(grad, shapeGrad))
+ .putEncryptedTensorMap("loss", ProtoUtils.bytesToCkksProto(loss, shapeLoss))
+ .build()
+
+ case FLPhase.EVAL =>
+ Log4Error.invalidOperationError(false, "Not supported")
+
+ case FLPhase.PREDICT =>
+ val meta = metaBuilder.setName("predictResult").setVersion(storage.version).build()
+ aggregatedTable = TensorMap.newBuilder()
+ .setMetaData(meta)
+ .putEncryptedTensorMap("predictOutput", ProtoUtils.bytesToCkksProto(output, shapeGrad))
+ .build()
+ }
+ storage.clearClientAndUpdateServer(aggregatedTable)
+ }
+
+}
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/VFLNNEstimator.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/VFLNNEstimator.scala
index 16c79b7405a..80935274d32 100644
--- a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/VFLNNEstimator.scala
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/VFLNNEstimator.scala
@@ -41,7 +41,6 @@ class VFLNNEstimator(algorithm: String,
threadNum: Int = 1) extends Estimator with FLClientClosable {
val logger = LogManager.getLogger(getClass)
val (weight, grad) = getParametersFromModel(model)
-
protected val evaluateResults = mutable.Map[String, ArrayBuffer[Float]]()
/**
@@ -62,7 +61,6 @@ class VFLNNEstimator(algorithm: String,
(0 until endEpoch).foreach { epoch =>
val dataSet = trainDataSet.data(true)
var count = 0
- var hasLabel = true
while (count < size) {
logger.debug(s"training next batch, progress: $count/$size, epoch: $epoch/$endEpoch")
val miniBatch = dataSet.next()
@@ -73,24 +71,7 @@ class VFLNNEstimator(algorithm: String,
.update("neval", iteration + 1)
val input = miniBatch.getInput()
val target = miniBatch.getTarget()
- if (target == null) hasLabel = false
- model.training()
- val output = model.forward(input)
-
- // Upload to PS
- val metadata = MetaData.newBuilder
- .setName(s"${model.getName()}_output").setVersion(iteration).build
- val tableProto = outputTargetToTableProto(model.output, target, metadata)
- model.zeroGradParameters()
- val gradInput = flClient.nnStub.train(tableProto, algorithm).getData
-
- // model replace
- val errors = getTensor("gradInput", gradInput)
- val loss = getTensor("loss", gradInput).value()
- model.backward(input, errors)
- logger.debug(s"Model doing backward, version: $iteration")
- optimMethod.optimize(_ => (loss, grad), weight)
-
+ trainStep(input, target, iteration)
iteration += 1
count += miniBatch.size()
}
@@ -98,12 +79,33 @@ class VFLNNEstimator(algorithm: String,
model.evaluate()
evaluate(valDataSet)
}
-
}
-
model
}
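+
+ /**
+ * One local train step: forward the local model, upload its output (and
+ * target, if any) to the server aggregator, then backward with the
+ * returned gradInput and update the weights with the returned loss.
+ */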
+ protected def trainStep(
+ input: Activity,
+ target: Activity,
+ iteration: Int): Unit = {
+ model.training()
+ model.zeroGradParameters()
+ val output = model.forward(input)
+
+ // Upload to PS
+ val metadata = MetaData.newBuilder
+ .setName(s"${model.getName()}_output").setVersion(iteration).build
+ val tableProto = outputTargetToTableProto(model.output, target, metadata)
+ val gradInput = flClient.nnStub.train(tableProto, algorithm)
+
+ // model replace
+ val errors = getTensor("gradInput", gradInput)
+ val loss = getTensor("loss", gradInput).mean()
+ logger.info(s"Loss: ${loss}")
+ model.backward(input, errors)
+ logger.debug(s"Model doing backward, version: $iteration")
+ optimMethod.optimize(_ => (loss, grad), weight)
+ }
+
/**
* Evaluate VFL model
* For each batch, client estimator upload output tensor to server aggregator,
@@ -156,17 +158,20 @@ class VFLNNEstimator(algorithm: String,
model.evaluate()
val miniBatch = data.next()
val input = miniBatch.getInput()
- val target = miniBatch.getTarget()
- val output = model.forward(input)
-
- val metadata = MetaData.newBuilder
- .setName(s"${model.getName()}_output").setVersion(iteration).build
- val tableProto = outputTargetToTableProto(model.output, target, metadata)
- val result = flClient.nnStub.predict(tableProto, algorithm).getData
- resultSeq = resultSeq :+ getTensor("predictOutput", result)
+ val result = predict(input, iteration)
+ resultSeq = resultSeq :+ result.toTensor[Float]
iteration += 1
count += miniBatch.size()
}
resultSeq.toArray
}
+
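+ /**
+ * Forward the local model and fetch the aggregated "predictOutput" from
+ * the server; no target is uploaded at predict time.
+ */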
+ protected def predict(input: Activity, iteration: Int): Activity = {
+ val output = model.forward(input)
+ val metadata = MetaData.newBuilder
+ .setName(s"${model.getName()}_output").setVersion(iteration).build
+ val tableProto = outputTargetToTableProto(model.output, null, metadata)
+ val result = flClient.nnStub.predict(tableProto, algorithm)
+ getTensor("predictOutput", result)
+ }
}
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/ckks/CAddTable.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/ckks/CAddTable.scala
new file mode 100644
index 00000000000..208142ea883
--- /dev/null
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/ckks/CAddTable.scala
@@ -0,0 +1,48 @@
+/*
+ * Copyright 2016 The BigDL Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.intel.analytics.bigdl.ppml.fl.nn.ckks
+
+import com.intel.analytics.bigdl.ckks.CKKS
+
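+/**
+ * CKKS counterpart of dllib's CAddTable: element-wise addition of serialized
+ * ciphertexts through the shared CKKS context behind ckksCommonPtr.
+ */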
+class CAddTable(val ckksCommonPtr: Long) {
+ val ckks = new CKKS()
+
+ def updateOutput(input: Array[Byte]*): Array[Byte] = {
+ // Fold the encrypted inputs into a single ciphertext sum via ckks.cadd.
+ var ckksOutput = input(0)
+ if (input.size > 1) {
+ (1 until input.size).foreach{i =>
+ ckksOutput = ckks.cadd(ckksCommonPtr, ckksOutput, input(i))
+ }
+ }
+ ckksOutput
+ }
+
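+ // Assumed calling convention (mirrors dllib's CAddTable): the last element
+ // passed here is the gradOutput, which is the gradient for every input of
+ // an add table, so it is returned unchanged.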
+ def updateGradInput(input: Array[Byte]*): Array[Byte] = {
+ input.last
+ }
+}
+
+
+object CAddTable {
+ def apply(ckksCommonPtr: Long): CAddTable = {
+ new CAddTable(ckksCommonPtr)
+ }
+}
+
+
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/ckks/Encryptor.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/ckks/Encryptor.scala
new file mode 100644
index 00000000000..7eb69239638
--- /dev/null
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/ckks/Encryptor.scala
@@ -0,0 +1,60 @@
+/*
+ * Copyright 2016 The BigDL Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.intel.analytics.bigdl.ppml.fl.nn.ckks
+
+import com.intel.analytics.bigdl.ckks.CKKS
+import com.intel.analytics.bigdl.dllib.nn.abstractnn.{AbstractModule, Activity}
+import com.intel.analytics.bigdl.dllib.tensor.{Storage, Tensor}
+import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath.TensorNumeric
+
+import scala.reflect.{ClassTag, classTag}
+
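+/**
+ * Module that encrypts its numeric input into a CKKS ciphertext tensor on
+ * forward and decrypts the encrypted gradOutput on backward, bridging a
+ * plaintext local model and a ciphertext-domain server aggregator.
+ */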
+class Encryptor[T: ClassTag](val ckksEncryptorPtr: Long)(implicit ev: TensorNumeric[T])
+ extends AbstractModule[Tensor[T], Tensor[Byte], T] {
+ val ckks = new CKKS()
+
+ val floatInput = Activity.allocate[Tensor[Float], Float]()
+
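+ // Forward: cast the input to Float if necessary, then CKKS-encrypt its
+ // flat storage.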
+ override def updateOutput(input: Tensor[T]): Tensor[Byte] = {
+ val floatInput = if (classTag[T] == classTag[Float]) {
+ input.toTensor[Float]
+ } else {
+ input.cast[Float](this.floatInput)
+ }
+ val enInput = ckks.ckksEncrypt(ckksEncryptorPtr, floatInput.storage().array())
+ output = Tensor[Byte](Storage[Byte](enInput)).resize(input.size())
+ output
+ }
+
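+ // Backward: decrypt the encrypted gradOutput back to Float and cast to T
+ // if necessary.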
+ override def updateGradInput(input: Tensor[T], gradOutput: Tensor[Byte]): Tensor[T] = {
+ val deGradOutput = ckks.ckksDecrypt(ckksEncryptorPtr, gradOutput.storage().array())
+ val floatGradInput = Tensor[Float](Storage[Float](deGradOutput)).resize(input.size())
+ gradInput = if (classTag[T] == classTag[Float]) {
+ floatGradInput.toTensor[T]
+ } else {
+ floatGradInput.cast[T](gradInput)
+ }
+ gradInput
+ }
+
+}
+
+object Encryptor {
+ def apply[T: ClassTag](
+ ckksEncryptorPtr: Long)(implicit ev: TensorNumeric[T]): Encryptor[T] = {
+ new Encryptor[T](ckksEncryptorPtr)
+ }
+}
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/ckks/FusedBCECriterion.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/ckks/FusedBCECriterion.scala
new file mode 100644
index 00000000000..b3899bb7a81
--- /dev/null
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/nn/ckks/FusedBCECriterion.scala
@@ -0,0 +1,55 @@
+/*
+ * Copyright 2016 The BigDL Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.intel.analytics.bigdl.ppml.fl.nn.ckks
+import com.intel.analytics.bigdl.ckks.CKKS
+import com.intel.analytics.bigdl.dllib.tensor.{Storage, Tensor}
+
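+/**
+ * Binary cross-entropy fused with the preceding sigmoid, evaluated on CKKS
+ * ciphertexts. A single native ckks.train call produces both the encrypted
+ * loss and the encrypted gradient; backward replays the cached gradient.
+ */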
+class FusedBCECriterion(val ckksCommonPtr: Long) {
+ val ckks = new CKKS()
+ var ckksOutput : Array[Array[Byte]] = null
+
+ def forward(input: Array[Byte], target: Array[Byte]): Array[Byte] = {
+ // One native call evaluates the fused sigmoid + BCE and returns
+ // Array(encryptedLoss, encryptedGradInput); cached for backward.
+ ckksOutput = ckks.train(ckksCommonPtr, input, target)
+ ckksOutput(0)
+ }
+
+ def forward(input: Tensor[Byte], target: Tensor[Byte]): Tensor[Byte] = {
+ val loss = forward(input.storage.array(), target.storage.array())
+ Tensor[Byte](Storage[Byte](loss)).resize(input.size())
+ }
+
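+ // The gradient was already produced by ckks.train in forward; backward
+ // only returns the cached encrypted value.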
+ def backward(input: Array[Byte], target: Array[Byte]): Array[Byte] = {
+ ckksOutput(1)
+ }
+
+ def backward(input: Tensor[Byte], target: Tensor[Byte]): Tensor[Byte] = {
+ Tensor[Byte](Storage(ckksOutput(1))).resize(input.size())
+ }
+}
+
+
+object FusedBCECriterion {
+ def apply(ckksCommonPtr: Long): FusedBCECriterion = {
+ new FusedBCECriterion(ckksCommonPtr)
+ }
+}
+
diff --git a/scala/ppml/src/test/scala/com/intel/analytics/bigdl/ppml/fl/utils/FlContextForTest.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/utils/FlContextForTest.scala
similarity index 100%
rename from scala/ppml/src/test/scala/com/intel/analytics/bigdl/ppml/fl/utils/FlContextForTest.scala
rename to scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/utils/FlContextForTest.scala
diff --git a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/utils/ProtoUtils.scala b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/utils/ProtoUtils.scala
index fa649fc762c..d539f077a0b 100644
--- a/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/utils/ProtoUtils.scala
+++ b/scala/ppml/src/main/scala/com/intel/analytics/bigdl/ppml/fl/utils/ProtoUtils.scala
@@ -16,6 +16,7 @@
package com.intel.analytics.bigdl.ppml.fl.utils
+import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity
import com.intel.analytics.bigdl.dllib.tensor.Tensor
@@ -35,6 +36,8 @@ import com.intel.analytics.bigdl.dllib.utils.Log4Error
import com.intel.analytics.bigdl.ppml.fl.FLClient
import com.intel.analytics.bigdl.ppml.fl.generated.FlBaseProto
+import scala.collection.mutable.ArrayBuffer
+
object ProtoUtils {
private val logger = LogManager.getLogger(getClass)
def outputTargetToTableProto(output: Activity,
@@ -54,6 +57,36 @@ object ProtoUtils {
}
builder.build()
}
+
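+ /**
+ * Unpack the encrypted client uploads: one ciphertext per client under the
+ * "output" key, the single encrypted "target" from the label holder, plus
+ * the plaintext shapes recorded for the gradient and the loss.
+ */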
+ def ckksProtoToBytes(storage: Storage[TensorMap]): (
+ Array[Array[Byte]], Array[Byte], Array[Int], Array[Int]) = {
+ val arrayBuffer = new ArrayBuffer[Array[Byte]](storage.clientData.size())
+ var targetBytes: Array[Byte] = null
+ var shapeGrad: Array[Int] = null
+ var shapeLoss: Array[Int] = null
+ storage.clientData.values().asScala.foreach(clientMap => {
+ val tensorMap = clientMap.getEncryptedTensorMapMap().asScala
+ if (tensorMap.contains("target")) {
+ Log4Error.invalidOperationError(targetBytes == null, "Target already exists")
+ targetBytes = tensorMap.get("target").get.getTensor.toByteArray
+ }
+ if (shapeGrad == null) {
+ shapeGrad = tensorMap.get("output").get.getShapeList.asScala.toArray.map(_.toInt)
+ shapeLoss = Array(shapeGrad(0))
+ }
+ arrayBuffer.append(tensorMap.get("output").get.getTensor.toByteArray)
+ })
+ (arrayBuffer.toArray, targetBytes, shapeGrad, shapeLoss)
+ }
+
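+ /**
+ * Wrap raw ciphertext bytes and their plaintext shape into an
+ * EncryptedTensor proto.
+ */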
+ def bytesToCkksProto(
+ bytes: Array[Byte],
+ shape: Array[Int]): EncryptedTensor = {
+ EncryptedTensor.newBuilder().setTensor(ByteString.copyFrom(bytes))
+ .addAllShape(shape.map(Integer.valueOf(_)).toIterable.asJava).build()
+ }
+
def tableProtoToOutputTarget(storage: Storage[TensorMap]): (DllibTable, Tensor[Float]) = {
val aggData = protoTableMapToTensorIterableMap(storage.clientData)
val target = Tensor[Float]()
@@ -133,7 +166,9 @@ object ProtoUtils {
val dataMap = modelData.getTensorMapMap.get(name)
val data = dataMap.getTensorList.asScala.map(Float2float).toArray
val shape = dataMap.getShapeList.asScala.map(Integer2int).toArray
- Tensor[Float](data, shape)
+ var totalSize = 1
+ shape.foreach(dimSize => totalSize *= dimSize)
+ Tensor[Float](data.slice(0, totalSize), shape)
}
diff --git a/scala/ppml/src/test/scala/com/intel/analytics/bigdl/ppml/fl/nn/CkksSpec.scala b/scala/ppml/src/test/scala/com/intel/analytics/bigdl/ppml/fl/nn/CkksSpec.scala
new file mode 100644
index 00000000000..f8747ab7f6b
--- /dev/null
+++ b/scala/ppml/src/test/scala/com/intel/analytics/bigdl/ppml/fl/nn/CkksSpec.scala
@@ -0,0 +1,262 @@
+/*
+ * Copyright 2016 The BigDL Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.intel.analytics.bigdl.ppml.fl.nn
+
+import com.intel.analytics.bigdl.ckks.CKKS
+import com.intel.analytics.bigdl.dllib.feature.dataset.{DataSet, Sample, SampleToMiniBatch}
+import com.intel.analytics.bigdl.dllib.nn._
+import com.intel.analytics.bigdl.dllib.optim.SGD
+import com.intel.analytics.bigdl.dllib.tensor.{Storage, Tensor}
+import com.intel.analytics.bigdl.ppml.BigDLSpecHelper
+import com.intel.analytics.bigdl.ppml.fl.nn.ckks.{Encryptor, FusedBCECriterion, CAddTable => CkksAddTable}
+
+import scala.util.Random
+
+class CkksSpec extends BigDLSpecHelper {
+ "ckks add" should "return right result" in {
+
+ val ckks = new CKKS()
+ val secrets = ckks.createSecrets()
+ val encryptorPtr = ckks.createCkksEncryptor(secrets)
+ val ckksRunnerPtr = ckks.createCkksCommonInstance(secrets)
+
+ val input1 = Array(0.1f, 0.2f, 1.1f, -1f)
+ val input2 = Array(-0.1f, 1.2f, 2.1f, 1f)
+ val enInput1 = ckks.ckksEncrypt(encryptorPtr, input1)
+ val enInput2 = ckks.ckksEncrypt(encryptorPtr, input2)
+
+ val cadd = CkksAddTable(ckksRunnerPtr)
+
+ val enOutput = cadd.updateOutput(enInput1, enInput2)
+ val output = ckks.ckksDecrypt(encryptorPtr, enOutput)
+ (0 until 4).foreach{i =>
+ output(i) should be (input1(i) + input2(i) +- 1e-5f)
+ }
+ }
+
+ "ckks layer" should "generate correct output and grad" in {
+ val module = new Sigmoid[Float]
+ val criterion = new BCECriterion[Float]()
+
+ val input = Tensor[Float](2, 2)
+ input(Array(1, 1)) = 0.063364277360961f
+ input(Array(1, 2)) = 0.90631252736785f
+ input(Array(2, 1)) = 0.22275671223179f
+ input(Array(2, 2)) = 0.37516756891273f
+ val target = Tensor[Float](2, 2)
+ target(Array(1, 1)) = 1
+ target(Array(1, 2)) = 1
+ target(Array(2, 1)) = 0
+ target(Array(2, 2)) = 1
+
+ val expectedOutput = module.forward(input)
+
+ val expectedLoss = criterion.forward(expectedOutput, target)
+ val expectedGradOutput = criterion.backward(expectedOutput, target)
+ val expectedGradInput = module.backward(input, expectedGradOutput)
+
+
+ val ckks = new CKKS()
+ val secrets = ckks.createSecrets()
+ val encryptorPtr = ckks.createCkksEncryptor(secrets)
+ val ckksRunnerPtr = ckks.createCkksCommonInstance(secrets)
+ val enTarget =
+ Tensor[Byte](Storage[Byte](ckks.ckksEncrypt(encryptorPtr, target.storage().array())))
+ .resize(target.size())
+
+ val module2 = Encryptor[Float](encryptorPtr)
+ val criterion2 = FusedBCECriterion(ckksRunnerPtr)
+
+ val output2 = module2.forward(input).toTensor[Byte]
+ val loss2 = criterion2.forward(output2, enTarget)
+ val gradOutput2 = criterion2.backward(output2, enTarget)
+ val gradInput2 = module2.backward(input, gradOutput2).toTensor[Float]
+
+
+ val enLoss = ckks.ckksDecrypt(encryptorPtr, loss2.storage().array)
+ gradInput2.div(4)
+ val loss = enLoss.slice(0, 4).sum / 4
+ loss should be (expectedLoss +- 0.02f)
+ (1 to 2).foreach{i =>
+ (1 to 2).foreach{j =>
+ gradInput2.valueAt(i, j) should be (expectedGradInput.valueAt(i, j) +- 0.02f)
+ }
+ }
+ }
+
+ "ckks jni api" should "generate correct output and grad" in {
+ val module = new Sigmoid[Float]
+ val criterion = new BCECriterion[Float]()
+ val input = Tensor[Float](2, 2)
+ input(Array(1, 1)) = 0.063364277360961f
+ input(Array(1, 2)) = 0.90631252736785f
+ input(Array(2, 1)) = 0.22275671223179f
+ input(Array(2, 2)) = 0.37516756891273f
+ val target = Tensor[Float](2, 2)
+ target(Array(1, 1)) = 1
+ target(Array(1, 2)) = 1
+ target(Array(2, 1)) = 0
+ target(Array(2, 2)) = 1
+
+ val expectedOutput = module.forward(input)
+
+ val expectedLoss = criterion.forward(expectedOutput, target)
+ val expectedGradOutput = criterion.backward(expectedOutput, target)
+ val expectedGradInput = module.backward(input, expectedGradOutput)
+
+ val ckks = new CKKS()
+ val secrets = ckks.createSecrets()
+ val encryptorPtr = ckks.createCkksEncryptor(secrets)
+ val ckksRunnerPtr = ckks.createCkksCommonInstance(secrets)
+ val enInput = ckks.ckksEncrypt(encryptorPtr, input.storage().array())
+ val enTarget = ckks.ckksEncrypt(encryptorPtr, target.storage().array())
+ val o = ckks.train(ckksRunnerPtr, enInput, enTarget)
+ val enLoss = ckks.ckksDecrypt(encryptorPtr, o(0))
+ val enGradInput2 = ckks.ckksDecrypt(encryptorPtr, o(1))
+ val gradInput2 = Tensor[Float](enGradInput2.slice(0, 4), Array(2, 2))
+ gradInput2.div(4)
+ val loss = enLoss.slice(0, 4).sum / 4
+ loss should be (expectedLoss +- 0.02f)
+ (1 to 2).foreach{i =>
+ (1 to 2).foreach{j =>
+ gradInput2.valueAt(i, j) should be (expectedGradInput.valueAt(i, j) +- 0.02f)
+ }
+ }
+ }
+
+ "ckks forward" should "generate correct output" in {
+ val module = new Sigmoid[Float]
+ val input = Tensor[Float](2, 4)
+ input(Array(1, 1)) = 0.063364277360961f
+ input(Array(1, 2)) = 0.90631252736785f
+ input(Array(1, 3)) = 0.22275671223179f
+ input(Array(1, 4)) = 0.37516756891273f
+ input(Array(2, 1)) = 0.99284988618456f
+ input(Array(2, 2)) = 0.97488326719031f
+ input(Array(2, 3)) = 0.94414822547697f
+ input(Array(2, 4)) = 0.68123375508003f
+ val expectedOutput = module.forward(input)
+
+ val ckks = new CKKS()
+ val secrets = ckks.createSecrets()
+ val encryptorPtr = ckks.createCkksEncryptor(secrets)
+ val ckksRunnerPtr = ckks.createCkksCommonInstance(secrets)
+ val enInput = ckks.ckksEncrypt(encryptorPtr, input.storage().array())
+ val enOutput = ckks.sigmoidForward(ckksRunnerPtr, enInput)
+ val outputArray = ckks.ckksDecrypt(encryptorPtr, enOutput(0))
+ val output = Tensor[Float](outputArray.slice(0, 8), Array(2, 4))
+ println(output)
+ println(expectedOutput)
+ (1 to 2).foreach{i =>
+ (1 to 4).foreach{j =>
+ output.valueAt(i, j) should be (expectedOutput.valueAt(i, j) +- 0.03f)
+ }
+ }
+ }
+
+ "ckks train" should "converge" in {
+ val random = new Random()
+ random.setSeed(10)
+ val featureLen = 10
+ val bs = 20
+ val totalSize = 1000
+ val dummyData = Array.tabulate(totalSize)(i =>
+ {
+ val features = Array.tabulate(featureLen)(_ => random.nextFloat())
+ val label = math.round(features.sum / featureLen).toFloat
+ Sample[Float](Tensor[Float](features, Array(featureLen)), label)
+ }
+ )
+ val dataset = DataSet.array(dummyData) ->
+ SampleToMiniBatch[Float](bs, parallelizing = false)
+
+ val module = Sequential[Float]()
+ module.add(Linear[Float](10, 1))
+ module.add(Sigmoid[Float]())
+ val criterion = new BCECriterion[Float]()
+ val sgd = new SGD[Float](0.1)
+ val sgd2 = new SGD[Float](0.1)
+ val (weight, gradient) = module.getParameters()
+
+ val module2 = Linear[Float](10, 1)
+ val (weight2, gradient2) = module2.getParameters()
+ weight2.copy(weight)
+ val ckks = new CKKS()
+ val secrets = ckks.createSecrets()
+ val encryptorPtr = ckks.createCkksEncryptor(secrets)
+ val ckksRunnerPtr = ckks.createCkksCommonInstance(secrets)
+
+ val epochNum = 2
+ val lossArray = new Array[Float](epochNum)
+ val loss2Array = new Array[Float](epochNum)
+ (0 until epochNum).foreach{epoch =>
+ var countLoss = 0f
+ var countLoss2 = 0f
+ dataset.shuffle()
+ val trainData = dataset.toLocal().data(false)
+ while(trainData.hasNext) {
+ val miniBatch = trainData.next()
+ val input = miniBatch.getInput()
+ val target = miniBatch.getTarget()
+ val output = module.forward(input)
+ val loss = criterion.forward(output, target)
+ countLoss += loss
+ val gradOutput = criterion.backward(output, target)
+ module.backward(input, gradOutput)
+ sgd.optimize(_ => (loss, gradient), weight)
+
+ val output2 = module2.forward(input).toTensor[Float]
+ val enInput = ckks.ckksEncrypt(encryptorPtr, output2.storage().array())
+ val enTarget = ckks.ckksEncrypt(encryptorPtr, target.toTensor[Float].storage().array())
+ val o = ckks.train(ckksRunnerPtr, enInput, enTarget)
+
+ val enLoss = ckks.ckksDecrypt(encryptorPtr, o(0))
+ val enGradInput2 = ckks.ckksDecrypt(encryptorPtr, o(1))
+ val gradInput2 = Tensor[Float](enGradInput2.slice(0, bs), Array(bs, 1))
+ gradInput2.div(bs)
+ module2.backward(input, gradInput2)
+ val loss2 = enLoss.slice(0, bs).sum / bs
+ sgd2.optimize(_ => (loss2, gradient2), weight2)
+ countLoss2 += loss2
+ module.zeroGradParameters()
+ module2.zeroGradParameters()
+ }
+ lossArray(epoch) = countLoss / (totalSize / bs)
+ loss2Array(epoch) = countLoss2 / (totalSize / bs)
+ println(countLoss / (totalSize / bs))
+ println(" " + countLoss2 / (totalSize / bs))
+ }
+ println("loss1: ")
+ println(lossArray.mkString("\n"))
+ println("loss2: ")
+ println(loss2Array.mkString("\n"))
+ lossArray.last - lossArray(0) should be (loss2Array.last -
+ loss2Array(0) +- 1e-2f)
+ }
+
+}
diff --git a/scala/ppml/src/test/scala/com/intel/analytics/bigdl/ppml/fl/nn/VFLCkksSpec.scala b/scala/ppml/src/test/scala/com/intel/analytics/bigdl/ppml/fl/nn/VFLCkksSpec.scala
new file mode 100644
index 00000000000..a003b053311
--- /dev/null
+++ b/scala/ppml/src/test/scala/com/intel/analytics/bigdl/ppml/fl/nn/VFLCkksSpec.scala
@@ -0,0 +1,66 @@
+/*
+ * Copyright 2016 The BigDL Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.intel.analytics.bigdl.ppml.fl.nn
+
+import com.intel.analytics.bigdl.ckks.CKKS
+import com.intel.analytics.bigdl.dllib.tensor.Tensor
+import com.intel.analytics.bigdl.ppml.fl.algorithms.VFLLogisticRegression
+import com.intel.analytics.bigdl.ppml.fl.utils.ProtoUtils.{getTensor, outputTargetToTableProto}
+import com.intel.analytics.bigdl.ppml.fl.vfl.NNStub
+import com.intel.analytics.bigdl.ppml.fl.{FLContext, FLServer, FLSpec}
+
+
+class VFLCkksSpec extends FLSpec {
+ "Encrypt and decrypt" should "work" in {
+ FLContext.initFLContext(1, target)
+ val ckks = new CKKS()
+ val secret = ckks.createSecrets()
+ val dataArray = Array(0.063364277360961f,
+ 0.90631252736785f,
+ 0.22275671223179f,
+ 0.37516756891273f)
+ val input = Tensor(dataArray, Array(1, 4))
+ val label = Tensor(Array(1f), Array(1, 1))
+ val tensorMap = outputTargetToTableProto(input, label, null)
+ val stub = new NNStub(FLContext.flClient.getChannel, 1, secret)
+ val encrypted = stub.encrypt(tensorMap)
+ val decrypted = stub.decrypt(encrypted)
+ val inputDecrypted = getTensor("output", decrypted)
+ val targetDecrypted = getTensor("target", decrypted)
+ input.almostEqual(inputDecrypted, 1e-5) should be (true)
+ label.almostEqual(targetDecrypted, 1e-5) should be (true)
+ }
+ "CKKS VFL LR" should "work" in {
+ val secret = new CKKS().createSecrets()
+ val flServer = new FLServer()
+ flServer.setPort(port)
+ flServer.setCkksAggregator(secret)
+ flServer.build()
+ flServer.start()
+
+ FLContext.initFLContext(1, target)
+ FLContext.flClient.initCkks(secret)
+ val lr = new VFLLogisticRegression(4, algorithm = "vfl_logistic_regression_ckks")
+ val input = Tensor(Array(0.063364277360961f,
+ 0.90631252736785f,
+ 0.22275671223179f,
+ 0.37516756891273f), Array(1, 4))
+ val label = Tensor(Array(1f), Array(1, 1))
+ lr.fit(input, label)
+ flServer.stop()
+ }
+}