diff --git a/core/src/main/java/com/datastrato/graviton/proto/ProtoEntitySerDe.java b/core/src/main/java/com/datastrato/graviton/proto/ProtoEntitySerDe.java index d530fc92359..a6a11b07e4a 100644 --- a/core/src/main/java/com/datastrato/graviton/proto/ProtoEntitySerDe.java +++ b/core/src/main/java/com/datastrato/graviton/proto/ProtoEntitySerDe.java @@ -12,7 +12,6 @@ import com.google.protobuf.Message; import java.io.IOException; import java.util.Map; -import java.util.Optional; public class ProtoEntitySerDe implements EntitySerDe { @@ -38,47 +37,14 @@ public class ProtoEntitySerDe implements EntitySerDe { private final Map, Class> entityToProto; - private final Map, Class> protoToEntity; - - public ProtoEntitySerDe() throws IOException { - ClassLoader loader = - Optional.ofNullable(Thread.currentThread().getContextClassLoader()) - .orElse(getClass().getClassLoader()); - - // TODO. This potentially has issues in creating serde objects, because the class load here - // may have no context for entities which are implemented in the specific catalog module. We - // should lazily create the serde class in the classloader when serializing and deserializing. + public ProtoEntitySerDe() { this.entityToSerDe = Maps.newHashMap(); - for (Map.Entry entry : ENTITY_TO_SERDE.entrySet()) { - String key = entry.getKey(); - String s = entry.getValue(); - Class entityClass = (Class) loadClass(key, loader); - Class> serdeClass = - (Class>) loadClass(s, loader); - - try { - ProtoSerDe serde = serdeClass.newInstance(); - entityToSerDe.put(entityClass, serde); - } catch (Exception exception) { - throw new IOException("Failed to instantiate serde class " + s, exception); - } - } - this.entityToProto = Maps.newHashMap(); - this.protoToEntity = Maps.newHashMap(); - for (Map.Entry entry : ENTITY_TO_PROTO.entrySet()) { - String e = entry.getKey(); - String p = entry.getValue(); - Class entityClass = (Class) loadClass(e, loader); - Class protoClass = (Class) loadClass(p, loader); - entityToProto.put(entityClass, protoClass); - protoToEntity.put(protoClass, entityClass); - } } @Override public byte[] serialize(T t) throws IOException { - Any any = Any.pack(toProto(t)); + Any any = Any.pack(toProto(t, Thread.currentThread().getContextClassLoader())); return any.toByteArray(); } @@ -86,44 +52,65 @@ public byte[] serialize(T t) throws IOException { public T deserialize(byte[] bytes, Class clazz, ClassLoader classLoader) throws IOException { Any any = Any.parseFrom(bytes); + Class protoClass = getProtoClass(clazz, classLoader); - if (!entityToSerDe.containsKey(clazz) || !entityToProto.containsKey(clazz)) { - throw new IOException("No proto and serde class found for entity " + clazz.getName()); - } - - if (!any.is(entityToProto.get(clazz))) { + if (!any.is(protoClass)) { throw new IOException("Invalid proto for entity " + clazz.getName()); } - try { - Class protoClazz = entityToProto.get(clazz); - Message anyMessage = any.unpack(protoClazz); - return fromProto(anyMessage); - } catch (Exception e) { - throw new IOException("Failed to deserialize entity " + clazz.getName(), e); - } + Message anyMessage = any.unpack(protoClass); + return fromProto(anyMessage, clazz, classLoader); } - public M toProto(T t) throws IOException { - if (!entityToSerDe.containsKey(t.getClass())) { - throw new IOException("No serde found for entity " + t.getClass().getName()); + private ProtoSerDe getProtoSerde( + Class entityClass, ClassLoader classLoader) throws IOException { + if (!ENTITY_TO_SERDE.containsKey(entityClass.getCanonicalName()) + || ENTITY_TO_SERDE.get(entityClass.getCanonicalName()) == null) { + throw new IOException("No serde found for entity " + entityClass.getCanonicalName()); } - - ProtoSerDe protoSerDe = (ProtoSerDe) entityToSerDe.get(t.getClass()); - return protoSerDe.serialize(t); + return (ProtoSerDe) + entityToSerDe.computeIfAbsent( + entityClass, + k -> { + try { + Class> serdeClazz = + (Class>) + loadClass(ENTITY_TO_SERDE.get(k.getCanonicalName()), classLoader); + return serdeClazz.newInstance(); + } catch (Exception e) { + throw new RuntimeException( + "Failed to instantiate serde class " + k.getCanonicalName(), e); + } + }); } - public T fromProto(M m) throws IOException { - if (!protoToEntity.containsKey(m.getClass())) { - throw new IOException("No entity class found for proto " + m.getClass().getName()); + private Class getProtoClass( + Class entityClass, ClassLoader classLoader) throws IOException { + if (!ENTITY_TO_PROTO.containsKey(entityClass.getCanonicalName()) + || ENTITY_TO_PROTO.get(entityClass.getCanonicalName()) == null) { + throw new IOException("No proto class found for entity " + entityClass.getCanonicalName()); } - Class entityClass = protoToEntity.get(m.getClass()); + return entityToProto.computeIfAbsent( + entityClass, + k -> { + try { + return (Class) + loadClass(ENTITY_TO_PROTO.get(k.getCanonicalName()), classLoader); + } catch (Exception e) { + throw new RuntimeException("Failed to create proto class " + k.getCanonicalName(), e); + } + }); + } - if (!entityToSerDe.containsKey(entityClass)) { - throw new IOException("No serde found for entity " + entityClass.getName()); - } + private M toProto(T t, ClassLoader classLoader) + throws IOException { + ProtoSerDe protoSerDe = (ProtoSerDe) getProtoSerde(t.getClass(), classLoader); + return protoSerDe.serialize(t); + } - ProtoSerDe protoSerDe = (ProtoSerDe) entityToSerDe.get(entityClass); + private T fromProto( + M m, Class entityClass, ClassLoader classLoader) throws IOException { + ProtoSerDe protoSerDe = getProtoSerde(entityClass, classLoader); return protoSerDe.deserialize(m); } @@ -131,7 +118,8 @@ private Class loadClass(String className, ClassLoader classLoader) throws IOE try { return Class.forName(className, true, classLoader); } catch (Exception e) { - throw new IOException("Failed to load class " + className, e); + throw new IOException( + "Failed to load class " + className + " with classLoader " + classLoader, e); } } } diff --git a/core/src/test/java/com/datastrato/graviton/proto/TestEntityProtoSerDe.java b/core/src/test/java/com/datastrato/graviton/proto/TestEntityProtoSerDe.java index 29b5f7eaeca..437e1c9d3e0 100644 --- a/core/src/test/java/com/datastrato/graviton/proto/TestEntityProtoSerDe.java +++ b/core/src/test/java/com/datastrato/graviton/proto/TestEntityProtoSerDe.java @@ -34,15 +34,10 @@ public void testAuditInfoSerDe() throws IOException { ProtoEntitySerDe protoEntitySerDe = (ProtoEntitySerDe) entitySerDe; - AuditInfo auditInfoProto = protoEntitySerDe.toProto(auditInfo); - Assertions.assertEquals(creator, auditInfoProto.getCreator()); - Assertions.assertEquals(now, ProtoUtils.toInstant(auditInfoProto.getCreateTime())); - Assertions.assertEquals(modifier, auditInfoProto.getLastModifier()); - Assertions.assertEquals(now, ProtoUtils.toInstant(auditInfoProto.getLastModifiedTime())); - - com.datastrato.graviton.meta.AuditInfo auditInfoFromProto = - protoEntitySerDe.fromProto(auditInfoProto); - Assertions.assertEquals(auditInfo, auditInfoFromProto); + byte[] bytes = protoEntitySerDe.serialize(auditInfo); + com.datastrato.graviton.meta.AuditInfo auditInfoFromBytes = + protoEntitySerDe.deserialize(bytes, com.datastrato.graviton.meta.AuditInfo.class); + Assertions.assertEquals(auditInfo, auditInfoFromBytes); // Test with optional fields com.datastrato.graviton.meta.AuditInfo auditInfo1 = @@ -51,19 +46,10 @@ public void testAuditInfoSerDe() throws IOException { .withCreateTime(now) .build(); - AuditInfo auditInfoProto1 = protoEntitySerDe.toProto(auditInfo1); - - Assertions.assertEquals(creator, auditInfoProto1.getCreator()); - Assertions.assertEquals(now, ProtoUtils.toInstant(auditInfoProto1.getCreateTime())); - - com.datastrato.graviton.meta.AuditInfo auditInfoFromProto1 = - protoEntitySerDe.fromProto(auditInfoProto1); - Assertions.assertEquals(auditInfo1, auditInfoFromProto1); - // Test from/to bytes - byte[] bytes = entitySerDe.serialize(auditInfo1); - com.datastrato.graviton.meta.AuditInfo auditInfoFromBytes = - entitySerDe.deserialize(bytes, com.datastrato.graviton.meta.AuditInfo.class); + bytes = protoEntitySerDe.serialize(auditInfo1); + auditInfoFromBytes = + protoEntitySerDe.deserialize(bytes, com.datastrato.graviton.meta.AuditInfo.class); Assertions.assertEquals(auditInfo1, auditInfoFromBytes); } @@ -94,12 +80,6 @@ public void testEntitiesSerDe() throws IOException { ProtoEntitySerDe protoEntitySerDe = (ProtoEntitySerDe) entitySerDe; - Metalake metalakeProto = protoEntitySerDe.toProto(metalake); - Assertions.assertEquals(props, metalakeProto.getPropertiesMap()); - com.datastrato.graviton.meta.BaseMetalake metalakeFromProto = - protoEntitySerDe.fromProto(metalakeProto); - Assertions.assertEquals(metalake, metalakeFromProto); - byte[] metalakeBytes = protoEntitySerDe.serialize(metalake); com.datastrato.graviton.meta.BaseMetalake metalakeFromBytes = protoEntitySerDe.deserialize( @@ -115,15 +95,10 @@ public void testEntitiesSerDe() throws IOException { .withVersion(version) .build(); - Metalake metalakeProto1 = protoEntitySerDe.toProto(metalake1); - Assertions.assertEquals(0, metalakeProto1.getPropertiesCount()); - com.datastrato.graviton.meta.BaseMetalake metalakeFromProto1 = - protoEntitySerDe.fromProto(metalakeProto1); - Assertions.assertEquals(metalake1, metalakeFromProto1); - - byte[] metalakeBytes1 = entitySerDe.serialize(metalake1); + byte[] metalakeBytes1 = protoEntitySerDe.serialize(metalake1); com.datastrato.graviton.meta.BaseMetalake metalakeFromBytes1 = - entitySerDe.deserialize(metalakeBytes1, com.datastrato.graviton.meta.BaseMetalake.class); + protoEntitySerDe.deserialize( + metalakeBytes1, com.datastrato.graviton.meta.BaseMetalake.class); Assertions.assertEquals(metalake1, metalakeFromBytes1); // Test CatalogEntity @@ -141,11 +116,6 @@ public void testEntitiesSerDe() throws IOException { .withAuditInfo(auditInfo) .build(); - Catalog catalogProto = protoEntitySerDe.toProto(catalogEntity); - com.datastrato.graviton.meta.CatalogEntity catalogEntityFromProto = - protoEntitySerDe.fromProto(catalogProto); - Assertions.assertEquals(catalogEntity, catalogEntityFromProto); - byte[] catalogBytes = protoEntitySerDe.serialize(catalogEntity); com.datastrato.graviton.meta.CatalogEntity catalogEntityFromBytes = protoEntitySerDe.deserialize(