Deploying a TensorFlow Model with a Java Spring Boot and Vue.js Frontend
Seamlessly serve TensorFlow models and consume them from a Java backend with a Vue.js frontend
Deploying a TensorFlow model in a production environment can be challenging. This article guides you through the process of deploying a TensorFlow model using TensorFlow Serving and creating a Java Spring Boot backend with a Vue.js frontend. We will cover the following steps:
- Creating a Java Spring Boot backend that connects to the TensorFlow model server
- Preparing and deploying the TensorFlow model server
- Building a Vue.js frontend to interact with the backend
Prerequisites
- Basic knowledge of Java and deep learning models
- A TensorFlow model in the saved_model format that accepts JPEG/PNG images (see the Appendix below)
- An AWS account and a running instance with the Deep Learning AMI
- A sample Spring Boot application
Creating a Java Spring Boot backend
We will start by creating a Java Spring Boot application that connects to the TensorFlow model server. The backend exposes an endpoint to upload images in JPEG or PNG format. We resize the images if they are larger than 1024×1024 pixels and convert them to base64. The images are sent to the TensorFlow model server, which returns an embedding vector for each image. We then return the prediction vector, prefixed with a hash of the image, to the user as plain text.
- Create a new Java Spring Boot project and add the necessary dependencies: Spring Web for the REST endpoint and Gson for JSON handling. No TensorFlow-specific dependency is needed, because the backend talks to the model server over plain HTTP (see the pom.xml sketch after this list).
- Create a Java class called TensorFlowClient that connects to the TensorFlow model server’s REST API. This class handles sending the base64-encoded images as a JSON array to the server and converting the returned embedding vectors into Java double arrays stored in a list.
- Implement a controller class called ImageController that handles the image upload, resizing, and conversion. The controller should call the TensorFlowClient class to send the image to the TensorFlow model server and receive the prediction vector.
- Note that the resizing could also be done inside the model, but we do it in Java to avoid sending very large images to the server (admittedly not very efficient either).
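As a rough sketch, assuming a Maven-based project, the relevant pom.xml entries could look like this (the Gson version is illustrative):
<dependency>
  <groupId>org.springframework.boot</groupId>
  <artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
  <groupId>com.google.code.gson</groupId>
  <artifactId>gson</artifactId>
  <version>2.10.1</version>
</dependency>
With the dependencies in place, the TensorFlowClient looks like this: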
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import java.lang.reflect.Type;

public class TensorFlowClient {

    // The model name in the URL must match the name in tensorflow_model_server.config.
    private static final String API_URL = "http://SERVER_IP_OR_HOST:9941/v1/models/my_model:predict";

    // Public so the controller can call it; returns one embedding vector per input image.
    public List<double[]> getEmbeddings(String base64Image) {
        List<double[]> embeddings = new ArrayList<>();
        try {
            String jsonString = sendRequest(base64Image);
            embeddings = parseResponse(jsonString);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return embeddings;
    }

    private static String sendRequest(String base64Image) throws Exception {
        URL url = new URL(API_URL);
        HttpURLConnection connection = (HttpURLConnection) url.openConnection();
        connection.setRequestMethod("POST");
        connection.setRequestProperty("Content-Type", "application/json");
        connection.setDoOutput(true);
        // TensorFlow Serving's REST API decodes {"b64": "..."} values into raw bytes
        // before they reach the model, so the model itself receives the image bytes.
        String input = String.format("{\"instances\": [{\"b64\": \"%s\"}]}", base64Image);
        OutputStream outputStream = connection.getOutputStream();
        outputStream.write(input.getBytes(StandardCharsets.UTF_8));
        outputStream.flush();
        BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(connection.getInputStream()));
        String line;
        StringBuilder response = new StringBuilder();
        while ((line = bufferedReader.readLine()) != null) {
            response.append(line);
        }
        bufferedReader.close();
        connection.disconnect();
        return response.toString();
    }

    private static List<double[]> parseResponse(String jsonString) {
        Gson gson = new Gson();
        // The server wraps the result in a "predictions" field:
        // {"predictions": [[0.12, -0.34, ...], ...]}
        Type responseType = new TypeToken<Map<String, List<List<Double>>>>() {}.getType();
        Map<String, List<List<Double>>> response = gson.fromJson(jsonString, responseType);
        List<double[]> embeddings = new ArrayList<>();
        for (List<Double> embeddingList : response.get("predictions")) {
            double[] embedding = embeddingList.stream().mapToDouble(Double::doubleValue).toArray();
            embeddings.add(embedding);
        }
        return embeddings;
    }
}
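For reference, a minimal, hypothetical smoke test for the client (test.jpg is an assumed local file):
public class TensorFlowClientDemo {
    public static void main(String[] args) throws Exception {
        byte[] bytes = java.nio.file.Files.readAllBytes(java.nio.file.Paths.get("test.jpg"));
        String base64 = java.util.Base64.getEncoder().encodeToString(bytes);
        java.util.List<double[]> embeddings = new TensorFlowClient().getEmbeddings(base64);
        System.out.println("Embedding length: " + embeddings.get(0).length);
    }
}
The controller below ties the upload endpoint to this client.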
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import java.util.List;

@SpringBootApplication
@RestController
public class ImageController {

    private final TensorFlowClient tensorFlowClient = new TensorFlowClient();

    public static void main(String[] args) {
        SpringApplication.run(ImageController.class, args);
    }

    @PostMapping("/image")
    public ResponseEntity<String> uploadImage(@RequestParam("file") MultipartFile file) {
        try {
            String fileType = file.getContentType();
            if (fileType == null || !fileType.matches("image/(jpeg|png)")) {
                return ResponseEntity.badRequest().body("Invalid file type. Only JPEG and PNG are allowed.");
            }
            BufferedImage image = ImageIO.read(file.getInputStream());
            image = resizeImage(image, 1024, 1024);
            ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
            ImageIO.write(image, fileType.split("/")[1], outputStream);
            outputStream.flush();
            byte[] imageBytes = outputStream.toByteArray();
            String base64Image = Base64.getEncoder().encodeToString(imageBytes);
            // Hash the (resized) image so the caller can tell which image a vector belongs to.
            MessageDigest digest = MessageDigest.getInstance("SHA-256");
            String imageHash = Base64.getEncoder().encodeToString(digest.digest(imageBytes));
            List<double[]> embeddings = tensorFlowClient.getEmbeddings(base64Image);
            double[] predictionVector = embeddings.get(0);
            String response = String.format("%s:%s", imageHash, arrayToString(predictionVector));
            return ResponseEntity.ok(response);
        } catch (IOException | NoSuchAlgorithmException e) {
            e.printStackTrace();
            return ResponseEntity.status(500).body("Internal server error");
        }
    }

    private BufferedImage resizeImage(BufferedImage originalImage, int maxWidth, int maxHeight) {
        int width = originalImage.getWidth();
        int height = originalImage.getHeight();
        double scaleFactor = Math.min((double) maxWidth / width, (double) maxHeight / height);
        if (scaleFactor >= 1.0) {
            // Image already fits within the limits; do not upscale.
            return originalImage;
        }
        int newWidth = (int) (width * scaleFactor);
        int newHeight = (int) (height * scaleFactor);
        Image scaledImage = originalImage.getScaledInstance(newWidth, newHeight, Image.SCALE_SMOOTH);
        // getType() can be TYPE_CUSTOM (0) for some PNGs, which BufferedImage rejects.
        int type = originalImage.getType() == 0 ? BufferedImage.TYPE_INT_RGB : originalImage.getType();
        BufferedImage resizedImage = new BufferedImage(newWidth, newHeight, type);
        Graphics2D graphics = resizedImage.createGraphics();
        graphics.drawImage(scaledImage, 0, 0, null);
        graphics.dispose();
        return resizedImage;
    }

    private String arrayToString(double[] array) {
        StringBuilder stringBuilder = new StringBuilder();
        for (double value : array) {
            stringBuilder.append(value).append(",");
        }
        return stringBuilder.deleteCharAt(stringBuilder.length() - 1).toString();
    }
}
Start the server, for example with the Maven wrapper:
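./mvnw spring-boot:run
By default, Spring Boot serves the backend on port 8080.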
Preparing and deploying the TensorFlow model server
To deploy the TensorFlow model server, follow these steps:
- Transfer your TensorFlow model and configuration files to your AWS instance using scp.
scp -i /path/to/your/aws/key.pem -r my_model ubuntu@<your_aws_instance_ip>:/home/ubuntu/
scp -i /path/to/your/aws/key.pem tensorflow_model_server.config ubuntu@<your_aws_instance_ip>:/home/ubuntu/
- Install TensorFlow Serving, or use the Deep Learning AMI, which comes with it preinstalled.
- Create a tensorflow_model_server.config file that specifies the model name, base path, and model version policy (an example is shown below).
- Start the TensorFlow model server with the configuration file, exposing the REST API on a specific port, e.g. 9941 (the port the Java client calls). Note that the gRPC port (--port) and the REST port (--rest_api_port) must differ:
tensorflow_model_server --port=8500 --rest_api_port=9941 --model_config_file=/home/ubuntu/tensorflow_model_server.config
An example tensorflow_model_server.config:
model_config_list: {
  config: {
    name: "my_model",
    base_path: "/home/ubuntu/my_model",
    model_platform: "tensorflow",
    model_version_policy: {
      latest {
        num_versions: 1
      }
    },
    version_labels {
      key: "stable",
      value: 1
    }
  }
}
In this example:
- name: The name under which the model is accessed through the REST API. In this case, it’s “my_model”.
- base_path: The path on the server where the model is stored. In this case, it’s “/home/ubuntu/my_model”, the path where you uploaded the model in a previous step.
- model_platform: The platform of the model, which should be “tensorflow” for TensorFlow models.
- model_version_policy: Specifies the version policy for the model. In this example, the server will serve only the latest version of the model.
- version_labels: You can set labels for specific model versions. In this case, the label “stable” is set for version 1.
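Before connecting the backend, you can sanity-check the model server via its status endpoint:
curl http://<your_aws_instance_ip>:9941/v1/models/my_model
This should report the served version with state AVAILABLE.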
Building a Vue.js frontend
Finally, we will create a simple Vue.js frontend that allows users to upload images and receive predictions from the TensorFlow model server.
1. Install Node.js, npm, and the Vue CLI.
npm install -g @vue/cli
2. Create a new Vue.js project and install the required dependencies.
vue create my-frontend
cd my-frontend
npm install axios
3. Implement the frontend by creating a Vue.js component that allows users to upload an image, and send the image to the Java Spring Boot backend using an HTTP request. Display the prediction result on the frontend.
(vue.config.js)
module.exports = {
  devServer: {
    proxy: {
      "/image": {
        // Proxy to the Spring Boot backend (default port 8080), not to
        // TensorFlow Serving; the backend forwards images to the model server.
        target: "http://<your_server_ip>:8080",
        changeOrigin: true,
      },
    },
  },
};
(src/App.vue)
<template>
  <div id="app">
    <h1>Image Prediction</h1>
    <input type="file" @change="uploadImage" />
    <button @click="submitImage" :disabled="!image">Submit</button>
    <div v-if="prediction">
      <h2>Prediction Result</h2>
      <p><strong>Embedding Vector:</strong> {{ prediction }}</p>
    </div>
  </div>
</template>
<script>
import axios from "axios";
export default {
  data() {
    return {
      image: null,
      prediction: null,
    };
  },
  methods: {
    uploadImage(event) {
      this.image = event.target.files[0];
    },
    async submitImage() {
      if (!this.image) {
        alert("Please select an image");
        return;
      }
      const formData = new FormData();
      formData.append("file", this.image);
      try {
        const response = await axios.post("/image", formData, {
          headers: {
            "Content-Type": "multipart/form-data",
          },
        });
        // The backend returns "<imageHash>:<comma-separated vector>".
        const [, vector] = response.data.split(":");
        this.prediction = vector;
      } catch (error) {
        console.error(error);
        alert("An error occurred while processing the image");
      }
    },
  },
};
</script>
<style>
#app {
  font-family: Avenir, Helvetica, Arial, sans-serif;
  text-align: center;
  color: #2c3e50;
  margin-top: 60px;
}
</style>
4. Start the development server and access the Vue.js frontend.
npm run serve
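For a production deployment, you would typically compile the frontend into static files and serve them from a web server or the Spring Boot application instead of the dev server:
npm run build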
Conclusion
In this article, we have demonstrated how to deploy a TensorFlow model using TensorFlow Serving and create a Java Spring Boot backend with a Vue.js frontend to interact with the model server. By following these steps, you can create a robust and scalable solution for deploying your TensorFlow models and serving predictions to users.
Appendix:
Make a TensorFlow 2 model accept JPEG/PNG input
import tensorflow as tf

class Base64DecoderLayer(tf.keras.layers.Layer):
    """
    Convert an incoming base64 string into a bitmap with RGB values between 0 and 1.
    target_size e.g. [height, width]
    """

    def __init__(self, target_size):
        super(Base64DecoderLayer, self).__init__()
        self.target_size = target_size

    def byte_to_img(self, byte_tensor):
        # base64 decoding is done by TensorFlow Serving, so the tensor arriving
        # here already contains the raw image bytes; decode_jpeg also decodes
        # PNGs, since the underlying op supports both formats
        imgs_map = tf.io.decode_jpeg(byte_tensor, channels=3)
        imgs_map.set_shape((None, None, 3))
        img = tf.image.resize(imgs_map, self.target_size, method=tf.image.ResizeMethod.BICUBIC)
        img = tf.cast(img, dtype=tf.float32) / 255
        return img

    def call(self, inputs, **kwargs):
        return tf.map_fn(self.byte_to_img, inputs, dtype=tf.float32)

# base an existing model on that layer (model, height and width are assumed
# to be defined elsewhere: the Keras model to wrap and its input size)
base64_input = tf.keras.Input(shape=(), dtype=tf.string, name='base64_in')
x = Base64DecoderLayer([height, width])(base64_input)
output_tensor = model(x)
my_model = tf.keras.Model(inputs=base64_input, outputs=output_tensor, name='MyModel')

# save the model in saved_model format; TensorFlow Serving expects a numeric
# version subdirectory, hence 'my_model/1'
tf.keras.models.save_model(my_model, 'my_model/1')
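To sanity-check the wrapped model locally before deploying it, you can load the SavedModel the way TensorFlow Serving does. This is a minimal sketch; test.jpg is an assumed local file, and note that outside of TensorFlow Serving there is no base64 decoding step, so the model receives raw image bytes:
import tensorflow as tf

loaded = tf.saved_model.load('my_model/1')
infer = loaded.signatures['serving_default']
with open('test.jpg', 'rb') as f:
    img_bytes = f.read()
# 'base64_in' matches the name of the Input layer defined above
print(infer(base64_in=tf.constant([img_bytes])))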