How to Use TensorFlow with Java

Introduction

Machine Learning is gaining popularity and usage across the globe. It has already drastically changed the way certain applications are built and will likely continue to be a huge (and growing) part of our daily lives.

There's no sugarcoating it: Machine Learning isn't simple. It can be daunting and seem very complex to many.

Companies such as Google took it upon themselves to bring Machine Learning concepts closer to developers and to help them take their first steps gradually, with plenty of support along the way.

Thus, frameworks such as TensorFlow were born.

What is TensorFlow?

TensorFlow is an open-source Machine Learning framework developed by Google in Python and C++.

It helps developers easily acquire data, prepare and train models, predict future states, and perform large-scale machine learning.

With it, we can train and run deep neural networks, which are commonly used for Optical Character Recognition, Image Recognition/Classification, Natural Language Processing, etc.

Tensors and Operations

TensorFlow is based on computational graphs, which you can imagine as a classic graph with nodes and edges.

Each node is an operation, which takes zero or more tensors as input and produces zero or more tensors as output. An operation can be very simple, such as basic addition, or very complex.

Tensors are depicted as the edges of the graph and are the core unit of data. We apply different functions to these tensors as we feed them to operations. A tensor can have any number of dimensions, and the number of dimensions is called its rank (scalar: rank 0, vector: rank 1, matrix: rank 2).

Data flows through the computational graph in the form of tensors and is transformed by operations along the way, hence the name TensorFlow.

There are three main types of tensors: placeholders, variables, and constants.
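The following is a minimal sketch of ranks in practice, assuming the TensorFlow 1.x Java API (the 1.13.1 version used in this article). It relies on Tensor.create(), which infers the element type and shape from a plain Java value, and on the numDimensions() and shape() accessors of Tensor. The class name TensorRanks is hypothetical and used only for illustration.

```java
import java.util.Arrays;
import org.tensorflow.Tensor;

public class TensorRanks {
    // Rank = number of dimensions of the tensor built from a Java value
    static int rankOf(Object value) {
        // Tensor.create infers the element type and shape from the Java object;
        // try-with-resources frees the tensor's native memory afterwards
        try (Tensor<?> t = Tensor.create(value)) {
            return t.numDimensions();
        }
    }

    // Shape = size of each dimension (an empty array for a scalar)
    static long[] shapeOf(Object value) {
        try (Tensor<?> t = Tensor.create(value)) {
            return t.shape();
        }
    }

    public static void main(String[] args) {
        Object scalar = 3.0f;
        Object vector = new float[] {1f, 2f, 3f};
        Object matrix = new float[][] {{1f, 2f}, {3f, 4f}};
        for (Object v : new Object[] {scalar, vector, matrix}) {
            System.out.println("rank " + rankOf(v) + ", shape " + Arrays.toString(shapeOf(v)));
        }
    }
}
```

The program should report rank 0 with an empty shape for the scalar, rank 1 with shape [3] for the vector, and rank 2 with shape [2, 2] for the matrix. Tensors hold native memory, which is why each one is closed with try-with-resources.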

Installing TensorFlow

With Maven, installing TensorFlow is as easy as adding the dependency:

<dependency>
  <groupId>org.tensorflow</groupId>
  <artifactId>tensorflow</artifactId>
  <version>1.13.1</version>
</dependency>

If your machine has a GPU that TensorFlow supports, use these dependencies instead:

<dependency>
  <groupId>org.tensorflow</groupId>
  <artifactId>libtensorflow</artifactId>
  <version>1.13.1</version>
</dependency>

<dependency>
  <groupId>org.tensorflow</groupId>
  <artifactId>libtensorflow_jni_gpu</artifactId>
  <version>1.13.1</version>
</dependency>

You can check which version of TensorFlow is installed using the TensorFlow class:

System.out.println(TensorFlow.version());
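A slightly fuller installation check, modeled on the HelloTensorFlow program from TensorFlow's Java install guide, builds a one-node graph containing a string constant and runs it. If it prints a greeting with the version number, the native library loaded correctly. This is a sketch assuming the 1.x API. The names HelloTensorFlow and hello() are chosen for illustration.

```java
import java.nio.charset.StandardCharsets;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;

public class HelloTensorFlow {
    // Builds a graph with a single string constant, runs it and returns the decoded result
    static String hello() {
        try (Graph g = new Graph()) {
            String value = "Hello from " + TensorFlow.version();
            // A byte[] becomes a scalar STRING tensor; its value is copied into the graph
            try (Tensor<?> t = Tensor.create(value.getBytes(StandardCharsets.UTF_8))) {
                g.opBuilder("Const", "MyConst")
                        .setAttr("dtype", t.dataType())
                        .setAttr("value", t)
                        .build();
            }
            // The session and the fetched tensor both hold native resources
            try (Session s = new Session(g);
                 Tensor<?> output = s.runner().fetch("MyConst").run().get(0)) {
                return new String(output.bytesValue(), StandardCharsets.UTF_8);
            }
        }
    }

    public static void main(String[] args) {
        System.out.println(hello());
    }
}
```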

TensorFlow Java API

TensorFlow's Java API is contained in the org.tensorflow package. It's currently experimental, so it isn't covered by TensorFlow's API stability guarantees.

Please note that the only fully supported language for TensorFlow is Python and that the Java API isn't nearly as functional.

It introduces new classes, an interface, an enum, and an exception.

Classes

The new classes introduced through the API are:

  • Graph: A data flow graph representing a TensorFlow computation
  • Operation: A Graph node that performs computation on Tensors
  • OperationBuilder: A builder class for Operations
  • Output<T>: A symbolic handle to a tensor produced by an Operation
  • SavedModelBundle: Represents a model loaded from storage
  • SavedModelBundle.Loader: Provides options for loading a SavedModel
  • Server: An in-process TensorFlow server, for use in distributed training
  • Session: Driver for Graph execution
  • Session.Run: Output tensors and metadata obtained when executing a session
  • Session.Runner: Run Operations and evaluate Tensors
  • Shape: The possibly partially known shape of a tensor produced by an operation
  • Tensor<T>: A statically typed multi-dimensional array whose elements are of a type described by T
  • TensorFlow: Static utility methods describing the TensorFlow runtime
  • Tensors: Type-safe factory methods for creating Tensor objects
Enum
  • DataType: Represents the type of elements in a Tensor as an enum
Interface
  • Operand<T>: Interface implemented by operands of a TensorFlow operation
Exception
  • TensorFlowException: Unchecked exception thrown when executing TensorFlow Graphs

If we compare all of this to the tf module in Python, there's an obvious difference: the Java API doesn't have nearly as much functionality, at least for now.

Graphs

As mentioned before, TensorFlow is based on computational graphs, and org.tensorflow.Graph is the Java API's implementation of one.

Note: Graph instances are thread-safe, but we need to explicitly release the resources a Graph uses by calling its close() method once we're finished with it.

Let's start off with an empty graph:

Graph graph = new Graph();

On its own, this graph doesn't do anything, since it's empty. To do anything with it, we first need to load it up with operations.

To add operations, we use the graph's opBuilder(type, name) method, which returns an OperationBuilder. The operation is added to the graph once we call the builder's build() method.
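The sketch below shows that calling opBuilder() alone doesn't change the graph. The operation only appears once build() is called. It also uses try-with-resources: Graph implements AutoCloseable, so its close() method, which releases the native resources mentioned in the note above, runs automatically. GraphBuilding and buildPlaceholder() are hypothetical names. The call graph.operation(name), which returns null for unknown names, is part of the 1.x Graph API.

```java
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.OperationBuilder;

public class GraphBuilding {
    // Returns {found before build(), found after build(), type of the built operation}
    static String[] buildPlaceholder() {
        // try-with-resources calls graph.close() for us
        try (Graph graph = new Graph()) {
            OperationBuilder builder = graph.opBuilder("Placeholder", "y")
                    .setAttr("dtype", DataType.FLOAT);
            // graph.operation(name) returns null when no such operation exists
            boolean before = graph.operation("y") != null;
            Operation y = builder.build(); // the operation is added to the graph here
            boolean after = graph.operation("y") != null;
            return new String[] {String.valueOf(before), String.valueOf(after), y.type()};
        }
    }

    public static void main(String[] args) {
        String[] r = buildPlaceholder();
        System.out.println("before build: " + r[0] + ", after build: " + r[1] + ", type: " + r[2]);
    }
}
```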

Constants

Let's add a constant to our graph:

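Operation x;
// setAttr copies the tensor's value into the graph, so it can be closed right after
try (Tensor<?> value = Tensor.create(3.0f)) {
    x = graph.opBuilder("Const", "x")
            .setAttr("dtype", DataType.FLOAT)
            .setAttr("value", value)
            .build();
}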

Placeholders

Placeholders are tensors that don't have a value when they're declared. Their values are fed in later, when the graph is run. This allows us to build graphs of operations without any actual data:

Operation y = graph.opBuilder("Placeholder", "y")
        .setAttr("dtype", DataType.FLOAT)
        .build();

Functions

Finally, to round this up, we need operations that actually compute something. These can be as simple as multiplication, division, or addition, or as complex as matrix multiplication. As before, we define them using the opBuilder() method:

Operation xy = graph.opBuilder("Mul", "xy")
  .addInput(x.output(0))
  .addInput(y.output(0))
  .build();         

Note: We're using output(0) because an operation can have more than one output. Index 0 selects its first output, which for Const and Placeholder is also the only one.
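The operations used so far each have a single output, so output(0) is the only choice. The sketch below uses TensorFlow's Unique op to show an operation with two outputs. Output 0 holds the distinct values of the input, and output 1 holds, for each input element, its position among those distinct values. The sketch assumes the 1.x Java API and runs the graph with a Session (described below), using the fetch(name, index) overload of Session.Runner to select each output. MultipleOutputs and uniqueWithIndices() are hypothetical names.

```java
import java.util.Arrays;
import java.util.List;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class MultipleOutputs {
    // Returns {distinct values, index of each input element within them}
    static int[][] uniqueWithIndices(int[] values) {
        try (Graph g = new Graph()) {
            Operation input;
            try (Tensor<?> t = Tensor.create(values)) { // rank-1 INT32 tensor
                input = g.opBuilder("Const", "input")
                        .setAttr("dtype", t.dataType())
                        .setAttr("value", t)
                        .build();
            }
            // "Unique" has two outputs, so unique.numOutputs() is 2
            g.opBuilder("Unique", "unique")
                    .addInput(input.output(0))
                    .build();

            try (Session s = new Session(g)) {
                // fetch(name, index) selects a specific output of an operation
                List<Tensor<?>> results = s.runner()
                        .fetch("unique", 0)
                        .fetch("unique", 1)
                        .run();
                try (Tensor<?> y = results.get(0); Tensor<?> idx = results.get(1)) {
                    int[] distinct = y.copyTo(new int[(int) y.shape()[0]]);
                    int[] indices = idx.copyTo(new int[values.length]);
                    return new int[][] {distinct, indices};
                }
            }
        }
    }

    public static void main(String[] args) {
        int[][] r = uniqueWithIndices(new int[] {1, 1, 2, 3, 3});
        System.out.println(Arrays.toString(r[0])); // [1, 2, 3]
        System.out.println(Arrays.toString(r[1])); // [0, 0, 1, 2, 2]
    }
}
```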

Graph Visualisation

Sadly, the Java API doesn't yet include any tools that allow you to visualize graphs as you would in Python. When the Java API gets updated, so will this article.

Sessions

As mentioned before, a Session is the driver for a Graph's execution. It encapsulates the environment in which Operations and Graphs are executed to compute Tensors.

In other words, the tensors in the graph we've constructed don't hold any values yet, because we haven't run the graph within a session.

Let's first create a session for the graph:

Session session = new Session(graph);

Our computation simply multiplies x and y. To run the graph and compute the result, we fetch() the output of the xy operation and feed() values for x and y. Feeding x overrides the constant's value of 3.0 for this run only:

try (Tensor<?> tensor = session.runner().fetch("xy")
        .feed("x", Tensor.create(5.0f)).feed("y", Tensor.create(2.0f)).run().get(0)) {
    System.out.println(tensor.floatValue()); // the result's native memory is freed on close
}

Running this piece of code will yield:

10.0
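Putting the pieces together, here is a self-contained sketch of the whole x * y example that also frees native memory. The Graph, the Session, every fed Tensor, and the fetched result are all closed with try-with-resources. It assumes the 1.x Java API. MultiplyExample, multiply() and runAndRead() are hypothetical names. Passing null for xOverride skips the feed for x, so the constant's own value of 3.0 is used.

```java
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class MultiplyExample {
    // Builds xy = x * y (x: constant 3.0, y: placeholder) and runs it.
    // A non-null xOverride is fed in place of the constant's value.
    static float multiply(Float xOverride, float y) {
        try (Graph graph = new Graph()) {
            Operation x;
            // The attribute value is copied into the graph, so the tensor can be closed right away
            try (Tensor<?> three = Tensor.create(3.0f)) {
                x = graph.opBuilder("Const", "x")
                        .setAttr("dtype", DataType.FLOAT)
                        .setAttr("value", three)
                        .build();
            }
            Operation yOp = graph.opBuilder("Placeholder", "y")
                    .setAttr("dtype", DataType.FLOAT)
                    .build();
            graph.opBuilder("Mul", "xy")
                    .addInput(x.output(0))
                    .addInput(yOp.output(0))
                    .build();

            try (Session session = new Session(graph);
                 Tensor<?> yValue = Tensor.create(y)) {
                Session.Runner runner = session.runner().fetch("xy").feed("y", yValue);
                if (xOverride == null) {
                    return runAndRead(runner);
                }
                try (Tensor<?> xValue = Tensor.create(xOverride)) {
                    return runAndRead(runner.feed("x", xValue));
                }
            }
        }
    }

    // Runs the session and reads the single fetched float, closing the result tensor
    private static float runAndRead(Session.Runner runner) {
        try (Tensor<?> result = runner.run().get(0)) {
            return result.floatValue();
        }
    }

    public static void main(String[] args) {
        System.out.println(multiply(5.0f, 2.0f)); // x fed as 5.0
        System.out.println(multiply(null, 2.0f)); // x keeps its constant value 3.0
    }
}
```

With these arguments, main() should print 10.0 and then 6.0.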

Saving Models in Python and Loading in Java

This may sound a bit odd, but since Python is the only fully supported language, the Java API doesn't yet have the functionality to save models.

This means the Java API is meant mainly for serving models, at least until it's fully supported by TensorFlow. We can, however, train and save a model in Python and then load it in Java to serve it, using the SavedModelBundle class. The snippet below assumes the model in ./model was exported with the "serve" tag and that its graph contains operations named x, y and xy:

try (SavedModelBundle model = SavedModelBundle.load("./model", "serve");
     Tensor<?> tensor = model.session().runner()
             .fetch("xy")
             .feed("x", Tensor.create(5.0f))
             .feed("y", Tensor.create(2.0f))
             .run().get(0)) {
    System.out.println(tensor.floatValue());
}
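The names passed to feed() and fetch() must match the operation names given in Python when the model was built, for example name="x". If you're unsure what they are, a small helper such as the hypothetical describeOperations() below can list every operation in the loaded graph. This is a sketch assuming the 1.x Java API and a model exported with the "serve" tag, as in the snippet above. It takes the export directory as a command-line argument. A SavedModel graph usually contains many extra operations, such as those used to save and restore variables, so expect a long list.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.SavedModelBundle;

public class ModelInspector {
    // Lists every operation in a graph as "name (type)"; feed()/fetch() take these names
    static List<String> describeOperations(Graph graph) {
        List<String> result = new ArrayList<>();
        Iterator<? extends Operation> it = graph.operations();
        while (it.hasNext()) {
            Operation op = it.next();
            result.add(op.name() + " (" + op.type() + ")");
        }
        return result;
    }

    public static void main(String[] args) {
        // args[0]: the SavedModel export directory, e.g. ./model
        try (SavedModelBundle model = SavedModelBundle.load(args[0], "serve")) {
            for (String line : describeOperations(model.graph())) {
                System.out.println(line);
            }
        }
    }
}
```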

Conclusion

TensorFlow is a powerful, robust and widely used framework. It's constantly being improved and has recently been brought to new languages, including Java and JavaScript.

Although the Java API doesn't yet have nearly as much functionality as TensorFlow for Python, it can still serve as a good intro to TensorFlow for Java developers.