diff --git a/central/src/main/java/ai/djl/serving/central/HttpStaticFileServerInitializer.java b/central/src/main/java/ai/djl/serving/central/HttpStaticFileServerInitializer.java index e54a1303280..5a7ef1fc1b8 100644 --- a/central/src/main/java/ai/djl/serving/central/HttpStaticFileServerInitializer.java +++ b/central/src/main/java/ai/djl/serving/central/HttpStaticFileServerInitializer.java @@ -13,6 +13,7 @@ package ai.djl.serving.central; import ai.djl.serving.central.handler.HttpStaticFileServerHandler; +import ai.djl.serving.central.handler.ModelDownloadHandler; import ai.djl.serving.central.handler.ModelMetaDataHandler; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; @@ -54,6 +55,7 @@ public void initChannel(SocketChannel ch) { pipeline.addLast(new HttpServerCodec()); pipeline.addLast(new HttpObjectAggregator(65536)); pipeline.addLast(new ChunkedWriteHandler()); + pipeline.addLast(new ModelDownloadHandler()); pipeline.addLast(new ModelMetaDataHandler()); pipeline.addLast(new HttpStaticFileServerHandler()); } diff --git a/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java b/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java new file mode 100644 index 00000000000..d6daf2944d6 --- /dev/null +++ b/central/src/main/java/ai/djl/serving/central/handler/ModelDownloadHandler.java @@ -0,0 +1,81 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.central.handler; + +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.serving.central.http.BadRequestException; +import ai.djl.serving.central.responseencoder.HttpRequestResponse; +import ai.djl.serving.central.utils.ModelUri; +import ai.djl.serving.central.utils.NettyUtils; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.QueryStringDecoder; +import java.io.IOException; +import java.util.Collections; +import java.util.concurrent.CompletableFuture; + +/** + * A handler to handle download requests from the ModelView. + * + * @author anfee1@morgan.edu + */ +public class ModelDownloadHandler extends SimpleChannelInboundHandler { + + HttpRequestResponse jsonResponse; + + /** Constructs a ModelDownloadHandler. */ + public ModelDownloadHandler() { + jsonResponse = new HttpRequestResponse(); + } + + /** + * Handles the deployment request by forwarding the request to the serving-instance. + * + * @param ctx the context + * @param request the full request + */ + @Override + protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) + throws IOException, ModelNotFoundException { + QueryStringDecoder decoder = new QueryStringDecoder(request.uri()); + String modelName = NettyUtils.getParameter(decoder, "modelName", null); + String modelGroupId = NettyUtils.getParameter(decoder, "groupId", null); + String modelArtifactId = NettyUtils.getParameter(decoder, "artifactId", null); + CompletableFuture.supplyAsync( + () -> { + try { + if (modelName != null) { + return ModelUri.uriFinder( + modelArtifactId, modelGroupId, modelName); + } else { + throw new BadRequestException("modelName is mandatory."); + } + + } catch (IOException | ModelNotFoundException ex) { + throw new IllegalArgumentException(ex.getMessage(), ex); + } + }) + .exceptionally((ex) -> Collections.emptyMap()) + .thenAccept(uriMap -> jsonResponse.sendAsJson(ctx, request, uriMap)); + } + + /** {@inheritDoc} */ + @Override + public boolean acceptInboundMessage(Object msg) { + FullHttpRequest request = (FullHttpRequest) msg; + + String uri = request.uri(); + return uri.startsWith("/serving/models?"); + } +} diff --git a/central/src/main/java/ai/djl/serving/central/http/BadRequestException.java b/central/src/main/java/ai/djl/serving/central/http/BadRequestException.java new file mode 100644 index 00000000000..c4905078b63 --- /dev/null +++ b/central/src/main/java/ai/djl/serving/central/http/BadRequestException.java @@ -0,0 +1,40 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.central.http; + +/** Thrown when a bad HTTP request is received. */ +public class BadRequestException extends IllegalArgumentException { + + static final long serialVersionUID = 1L; + + /** + * Constructs an {@code BadRequestException} with the specified detail message. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + */ + public BadRequestException(String message) { + super(message); + } + + /** + * Constructs an {@code BadRequestException} with the specified detail message and a root cause. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + * @param cause root cause + */ + public BadRequestException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/central/src/main/java/ai/djl/serving/central/http/package-info.java b/central/src/main/java/ai/djl/serving/central/http/package-info.java new file mode 100644 index 00000000000..fb26e5c6c82 --- /dev/null +++ b/central/src/main/java/ai/djl/serving/central/http/package-info.java @@ -0,0 +1,14 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +/** Contains HTTP codes. */ +package ai.djl.serving.central.http; diff --git a/central/src/main/java/ai/djl/serving/central/responseencoder/HttpRequestResponse.java b/central/src/main/java/ai/djl/serving/central/responseencoder/HttpRequestResponse.java new file mode 100644 index 00000000000..59e97dffdfd --- /dev/null +++ b/central/src/main/java/ai/djl/serving/central/responseencoder/HttpRequestResponse.java @@ -0,0 +1,123 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.central.responseencoder; + +import ai.djl.modality.Classifications; +import ai.djl.modality.Classifications.ClassificationsSerializer; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.repository.Metadata; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.JsonPrimitive; +import com.google.gson.JsonSerializer; +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.util.CharsetUtil; +import java.lang.reflect.Modifier; + +/** + * Serialize to json and send the response to the client. + * + * @author erik.bamberg@web.de + */ +public class HttpRequestResponse { + + private static final Gson GSON_WITH_TRANSIENT_FIELDS = + new GsonBuilder() + .setDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") + .setPrettyPrinting() + .excludeFieldsWithModifiers(Modifier.STATIC) + .registerTypeAdapter(Classifications.class, new ClassificationsSerializer()) + .registerTypeAdapter(DetectedObjects.class, new ClassificationsSerializer()) + .registerTypeAdapter(Metadata.class, new MetaDataSerializer()) + .registerTypeAdapter( + Double.class, + (JsonSerializer) + (src, t, ctx) -> { + long v = src.longValue(); + if (src.equals(Double.valueOf(String.valueOf(v)))) { + return new JsonPrimitive(v); + } + return new JsonPrimitive(src); + }) + .create(); + + /** + * send a response to the client. + * + * @param ctx channel context + * @param request full request + * @param entity the response + */ + public void sendAsJson(ChannelHandlerContext ctx, FullHttpRequest request, Object entity) { + + String serialized = GSON_WITH_TRANSIENT_FIELDS.toJson(entity); + ByteBuf buffer = ctx.alloc().buffer(serialized.length()); + buffer.writeCharSequence(serialized, CharsetUtil.UTF_8); + + FullHttpResponse response = + new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, buffer); + response.headers().set(HttpHeaderNames.CONTENT_TYPE, "application/json; charset=UTF-8"); + boolean keepAlive = HttpUtil.isKeepAlive(request); + this.sendAndCleanupConnection(ctx, response, keepAlive); + } + + /** + * send content of a ByteBuffer as response to the client. + * + * @param ctx channel context + * @param buffer response buffer + */ + public void sendByteBuffer(ChannelHandlerContext ctx, ByteBuf buffer) { + + FullHttpResponse response = + new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, buffer); + response.headers().set(HttpHeaderNames.CONTENT_TYPE, "application/json; charset=UTF-8"); + this.sendAndCleanupConnection(ctx, response, false); + } + + /** + * If Keep-Alive is disabled, attaches "Connection: close" header to the response and closes the + * connection after the response being sent. + * + * @param ctx context + * @param response full response + * @param keepAlive is alive or not + */ + private void sendAndCleanupConnection( + ChannelHandlerContext ctx, FullHttpResponse response, boolean keepAlive) { + HttpUtil.setContentLength(response, response.content().readableBytes()); + if (!keepAlive) { + // We're going to close the connection as soon as the response is sent, + // so we should also make it clear for the client. + response.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE); + } + + ChannelFuture flushPromise = ctx.writeAndFlush(response); + + if (!keepAlive) { + // Close the connection as soon as the response is sent. + flushPromise.addListener(ChannelFutureListener.CLOSE); + } + } +} diff --git a/central/src/main/java/ai/djl/serving/central/utils/ModelUri.java b/central/src/main/java/ai/djl/serving/central/utils/ModelUri.java new file mode 100644 index 00000000000..b1bc9a6a691 --- /dev/null +++ b/central/src/main/java/ai/djl/serving/central/utils/ModelUri.java @@ -0,0 +1,71 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.central.utils; + +import ai.djl.Application; +import ai.djl.repository.Artifact; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ModelZoo; +import java.io.IOException; +import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** A class to find the URIs when given a model name. */ +public final class ModelUri { + + // TODO: Use the artifact repository to create base URI + private static URI base = URI.create("https://mlrepo.djl.ai/"); + + private ModelUri() {} + + /** + * Takes in a model name, artifactId, and groupId to return a Map of download URIs. + * + * @param artifactId is the artifactId of the model + * @param groupId is the groupId of the model + * @param name is the name of the model + * @return a map of download URIs + * @throws IOException if the uri could not be found + * @throws ModelNotFoundException if Model can not be found + */ + public static Map uriFinder(String artifactId, String groupId, String name) + throws IOException, ModelNotFoundException { + Criteria criteria = + Criteria.builder() + .optModelName(name) + .optGroupId(groupId) + .optArtifactId(artifactId) + .build(); + Map> models = ModelZoo.listModels(criteria); + Map uris = new ConcurrentHashMap<>(); + models.forEach( + (app, list) -> { + list.forEach( + artifact -> { + for (Map.Entry entry : + artifact.getFiles().entrySet()) { + URI fileUri = URI.create(entry.getValue().getUri()); + URI baseUri = artifact.getMetadata().getRepositoryUri(); + if (!fileUri.isAbsolute()) { + fileUri = base.resolve(baseUri).resolve(fileUri); + } + uris.put(entry.getKey(), fileUri); + } + }); + }); + return uris; + } +} diff --git a/central/src/main/java/ai/djl/serving/central/utils/NettyUtils.java b/central/src/main/java/ai/djl/serving/central/utils/NettyUtils.java new file mode 100644 index 00000000000..ce0e3623cf5 --- /dev/null +++ b/central/src/main/java/ai/djl/serving/central/utils/NettyUtils.java @@ -0,0 +1,109 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.central.utils; + +import ai.djl.modality.Input; +import io.netty.buffer.ByteBuf; +import io.netty.handler.codec.http.QueryStringDecoder; +import io.netty.handler.codec.http.multipart.Attribute; +import io.netty.handler.codec.http.multipart.FileUpload; +import io.netty.handler.codec.http.multipart.InterfaceHttpData; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; + +/** A utility class that handling Netty request and response. */ +public final class NettyUtils { + + private NettyUtils() {} + + /** + * Returns the bytes for the specified {@code ByteBuf}. + * + * @param buf the {@code ByteBuf} to read + * @return the bytes for the specified {@code ByteBuf} + */ + public static byte[] getBytes(ByteBuf buf) { + if (buf.hasArray()) { + return buf.array(); + } + + byte[] ret = new byte[buf.readableBytes()]; + int readerIndex = buf.readerIndex(); + buf.getBytes(readerIndex, ret); + return ret; + } + + /** + * Reads the parameter's value for the key from the uri. + * + * @param decoder the {@code QueryStringDecoder} parsed from uri + * @param key the parameter key + * @param def the default value + * @return the parameter's value + */ + public static String getParameter(QueryStringDecoder decoder, String key, String def) { + List param = decoder.parameters().get(key); + if (param != null && !param.isEmpty()) { + return param.get(0); + } + return def; + } + + /** + * Read the parameter's integer value for the key from the uri. + * + * @param decoder the {@code QueryStringDecoder} parsed from uri + * @param key the parameter key + * @param def the default value + * @return the parameter's integer value + * @throws NumberFormatException exception is thrown when the parameter-value is not numeric. + */ + public static int getIntParameter(QueryStringDecoder decoder, String key, int def) { + String value = getParameter(decoder, key, null); + if (value == null || value.isEmpty()) { + return def; + } + return Integer.parseInt(value); + } + + /** + * Parses form data and added to the {@link Input} object. + * + * @param data the form data + * @param input the {@link Input} object to be added to + */ + public static void addFormData(InterfaceHttpData data, Input input) { + if (data == null) { + return; + } + try { + String name = data.getName(); + switch (data.getHttpDataType()) { + case Attribute: + Attribute attribute = (Attribute) data; + input.addData(name, attribute.getValue().getBytes(StandardCharsets.UTF_8)); + break; + case FileUpload: + FileUpload fileUpload = (FileUpload) data; + input.addData(name, getBytes(fileUpload.getByteBuf())); + break; + default: + throw new IllegalArgumentException( + "Except form field, but got " + data.getHttpDataType()); + } + } catch (IOException e) { + throw new AssertionError(e); + } + } +} diff --git a/central/src/main/java/ai/djl/serving/central/utils/package-info.java b/central/src/main/java/ai/djl/serving/central/utils/package-info.java new file mode 100644 index 00000000000..8bee987b03f --- /dev/null +++ b/central/src/main/java/ai/djl/serving/central/utils/package-info.java @@ -0,0 +1,14 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +/** Contains utility classes that hand response and requests. */ +package ai.djl.serving.central.utils; diff --git a/central/src/main/webapp/components/DownloadButtons.jsx b/central/src/main/webapp/components/DownloadButtons.jsx new file mode 100644 index 00000000000..97374cf0e5d --- /dev/null +++ b/central/src/main/webapp/components/DownloadButtons.jsx @@ -0,0 +1,46 @@ +import React, { Component, useState, useEffect, useRef } from "react"; +import Button from '@material-ui/core/Button'; +import ReactDOM from 'react-dom'; + +import { makeStyles } from '@material-ui/core/styles'; +import axios from 'axios' + + +const useFetch = (model) => { + const [data, setData] = useState([]); + + useEffect(() => { + async function fetchData() { + + axios.get("http://"+window.location.host+"/serving/models?modelName="+model.name+"&artifactId="+model.metadata.artifactId+"&groupId="+model.metadata.groupId) + .then(function(response) { + let appdata = Object.keys(response.data).map(function(key) { + return { + key: key, + link: response.data[key] + }; + }); + setData(appdata); + console.log(appdata) + }) + } + fetchData(); + }, [model.modelName,model.metadata.artifactId,model.metadata.groupId]); + + return data; +}; + + + +export default function ModelDownloadButtons(props) { + const modelUris = useFetch(props.model); + return ( + <> + {Object.keys(modelUris).map((keys) => ( + + + ) + )} + + ); +} diff --git a/central/src/main/webapp/components/ModelView.jsx b/central/src/main/webapp/components/ModelView.jsx index 3cfeb6d0f51..f5db8853ae9 100644 --- a/central/src/main/webapp/components/ModelView.jsx +++ b/central/src/main/webapp/components/ModelView.jsx @@ -19,6 +19,7 @@ import Chip from '@material-ui/core/Chip'; import Divider from '@material-ui/core/Divider'; import DynForm from './DynForm'; +import ModelDownloadButtons from './DownloadButtons'; import axios from 'axios' @@ -186,6 +187,7 @@ export default function ModelView(props) { + @@ -250,6 +252,9 @@ export default function ModelView(props) { : } + + +