Understanding TensorFlow's @tf.function Decorator

Introduction

Improving the performance of a training loop can save hours of computing time when training machine learning models. One of the ways of improving the performance of TensorFlow code is using the tf.function() decorator - a simple, one-line change that can make your functions run significantly faster.

In this short guide, we will explain how tf.function() improves performance and take a look at some best practices.

Python Decorators and tf.function()

In Python, a decorator is a function that takes another function and returns a modified version of it. To see what tf.function() does, suppose you run the following function in a notebook cell:

import tensorflow as tf

x = tf.random.uniform(shape=[100, 100], minval=-1, maxval=1, dtype=tf.dtypes.float32)

def some_costly_computation(x):
    aux = tf.eye(100, dtype=tf.dtypes.float32)
    result = tf.zeros(100, dtype = tf.dtypes.float32)
    for i in range(1,100):
        aux = tf.matmul(x,aux)/i
        result = result + aux
    return result

%timeit some_costly_computation(x)
16.2 ms ± 103 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

However, if we pass the costly function to tf.function():

quicker_computation = tf.function(some_costly_computation)
%timeit quicker_computation(x)

We get quicker_computation() - a new function that performs much faster than the previous one:

4.99 ms ± 139 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

So, tf.function() takes some_costly_computation() and returns a new, faster function, quicker_computation(), leaving the original unchanged. Taking a function and returning a modified one is exactly what a decorator does, so it was natural to make tf.function() a decorator as well.

Using the decorator notation is the same as calling tf.function(function):

@tf.function
def quick_computation(x):
  aux = tf.eye(100, dtype=tf.dtypes.float32)
  result = tf.zeros(100, dtype = tf.dtypes.float32)
  for i in range(1,100):
    aux = tf.matmul(x,aux)/i
    result = result + aux
  return result

%timeit quick_computation(x)
5.09 ms ± 283 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
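The @ notation is ordinary Python rather than a TensorFlow feature: for any decorator, placing @decorator above def f is shorthand for f = decorator(f). As a plain-Python sketch, the hypothetical count_calls decorator below (not part of any library) is applied both ways:

```python
import functools


def count_calls(func):
    """Hypothetical decorator: returns a wrapper that counts how often func is called."""
    @functools.wraps(func)  # keep the original name and docstring on the wrapper
    def wrapper(*args, **kwargs):
        wrapper.calls += 1
        return func(*args, **kwargs)

    wrapper.calls = 0
    return wrapper


def square(v):
    return v * v


# Explicit form, like tf.function(some_costly_computation)
counted_square = count_calls(square)


# Decorator form, equivalent to: cube = count_calls(cube)
@count_calls
def cube(v):
    return v * v * v


print(counted_square(3), counted_square.calls)  # 9 1
print(cube(2), cube.calls)                      # 8 1
```

As with tf.function(), the undecorated square() is left untouched; only the returned wrapper has the new behavior.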

How Does tf.function() Work?

How come wrapping a function in tf.function() can make it run about three times faster, as it did here?

TensorFlow code can be run in two modes: eager mode and graph mode. Eager mode is the default, interactive way to run code: every operation is executed immediately, as soon as Python reaches it.

Graph mode, however, is a little bit different. In graph mode, before executing the function, TensorFlow creates a computation graph: a data structure containing the operations required to execute the function. The computation graph allows TensorFlow to simplify the computations and find opportunities for parallelization. The graph also isolates the function from the surrounding Python code, allowing it to run efficiently on many different devices.

A function decorated with @tf.function is executed in two steps:

  1. In the first step, TensorFlow executes the Python code for the function and compiles a computation graph, delaying the execution of any TensorFlow operation.
  2. Afterwards, the computation graph is run.

Note: The first step is known as "tracing".
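To see the two modes side by side, the sketch below runs the same hypothetical add_one() function eagerly and through tf.function(), using tf.executing_eagerly() (which returns False while TensorFlow is tracing a function) to record how the Python body was run. It then inspects the traced graph through get_concrete_function(), which returns the graph built for a given input signature.

```python
import tensorflow as tf

modes = []  # records tf.executing_eagerly() each time the Python body runs


def add_one(v):
    # True in eager mode; False while TensorFlow is tracing the function into a graph
    modes.append(tf.executing_eagerly())
    return v + 1.0


add_one(tf.constant(1.0))          # eager: the Python body runs and executes ops immediately

graph_add_one = tf.function(add_one)
graph_add_one(tf.constant(1.0))    # first call: tracing runs the Python body, then the graph runs
graph_add_one(tf.constant(5.0))    # same dtype and shape: the traced graph is reused

print(modes)  # [True, False]

# The traced graph is an ordinary data structure that can be inspected
concrete = graph_add_one.get_concrete_function(tf.constant(1.0))
print([op.type for op in concrete.graph.get_operations()])
```

Only two entries end up in modes: the third call reuses the graph traced by the second, so its Python body never runs. The exact list of op types printed at the end may vary between TensorFlow versions.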

TensorFlow skips the first step when it can reuse a graph it has already traced, for example when the function is called again with tensor arguments of the same dtype and shape. This improves the performance of the function, but it also means that the function does not execute like regular Python code, in which every line runs on every call. For example, let's modify our previous function:

@tf.function
def quick_computation(x):
  print('Only prints the first time!')
  aux = tf.eye(100, dtype=tf.dtypes.float32)
  result = tf.zeros(100, dtype = tf.dtypes.float32)
  for i in range(1,100):
    aux = tf.matmul(x,aux)/i
    result = result + aux
  return result

quick_computation(x)
quick_computation(x)

This results in:

Only prints the first time!

The print() is executed only once, during tracing, which is when the regular Python code runs. Later calls to the function only execute the TensorFlow operations recorded in the computation graph.

However, if we use tf.print() instead:

@tf.function
def quick_computation_with_print(x):
  tf.print("Prints every time!")
  aux = tf.eye(100, dtype=tf.dtypes.float32)
  result = tf.zeros(100, dtype = tf.dtypes.float32)
  for i in range(1,100):
    aux = tf.matmul(x,aux)/i
    result = result + aux
  return result

quick_computation_with_print(x)
quick_computation_with_print(x)


Prints every time!
Prints every time!

TensorFlow includes tf.print() in its computation graph as it's a TensorFlow operation - not a regular Python function.

Warning: Not all Python code is executed in every call to a function decorated with @tf.function. After tracing, only the operations from the computational graph are run, which means some care must be taken in our code.
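The "first time" in the print() example really means "each time TensorFlow traces": tf.function keeps one graph per input signature and, with default settings, traces again when it receives tensors with a new dtype or shape. The sketch below uses a hypothetical traces list, appended to from plain Python code, to make each trace visible:

```python
import tensorflow as tf

traces = []  # hypothetical log; appended to only while the function is being traced


@tf.function
def double(v):
    traces.append((v.dtype.name, v.shape.as_list()))  # plain Python: runs during tracing only
    return v * 2


double(tf.constant(1.0))          # traces: float32 scalar
double(tf.constant(2.0))          # reuses the graph: same dtype and shape
double(tf.constant([1.0, 2.0]))   # traces again: new shape
double(tf.constant(1))            # traces again: new dtype (int32)

print(traces)
```

Only three of the four calls trace; the second call reuses the graph built for the first. Frequent retracing can cancel out the speed-up of tf.function, which is one more reason to keep argument dtypes and shapes stable.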

Best Practices with @tf.function

Writing Code with TensorFlow Operations

As the print() example showed, plain Python code is not recorded in the computation graph: it runs during tracing and is then skipped. This makes it hard to predict how a function written with "normal" Python code will behave. To avoid unexpected behavior, write your function with TensorFlow operations where possible.

For instance, AutoGraph, which tf.function uses by default to convert Python control flow, turns a for or while loop into a TensorFlow loop only when the loop iterates over a tensor or depends on a tensor condition; a loop over a Python range() is unrolled during tracing instead. Either way, a loop processes one step at a time, so it is better to write it as a vectorized operation when possible. This usually improves performance and keeps the traced graph small and predictable.

As an example, consider the following:

x = tf.random.uniform(shape=[100, 100], minval=-1, maxval=1, dtype=tf.dtypes.float32)

@tf.function
def function_with_for(x):
  summ = 0.0
  for row in x:
    summ = summ + tf.reduce_mean(row)
  return summ

@tf.function
def vectorized_function(x):
  result = tf.reduce_mean(x, axis=0)
  return tf.reduce_sum(result)


print(function_with_for(x))
print(vectorized_function(x))

%timeit function_with_for(x)
%timeit vectorized_function(x)
tf.Tensor(0.672811, shape=(), dtype=float32)
tf.Tensor(0.67281103, shape=(), dtype=float32)
1.58 ms ± 177 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
440 µs ± 8.34 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

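The vectorized version, which replaces the loop with a single tf.reduce_mean() and tf.reduce_sum(), returns the same result (up to floating-point rounding) and runs about 3.5 times faster: 440 µs versus 1.58 ms.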
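Whether AutoGraph turns a loop into a graph loop depends on what it iterates over. The sketch below uses two hypothetical functions, python_loop() and tensor_loop(), that double a value ten times, and counts the multiplication (Mul) ops in each traced top-level graph. The Python range() loop should contribute one Mul op per unrolled iteration, while the tf.range() loop becomes a single while-loop op whose body lives in a separate function graph.

```python
import tensorflow as tf


@tf.function
def python_loop(v):
    for _ in range(10):        # Python range: unrolled while tracing
        v = v * 2.0
    return v


@tf.function
def tensor_loop(v):
    for _ in tf.range(10):     # tf.range: AutoGraph builds a single while-loop op
        v = v * 2.0
    return v


start = tf.constant(1.0)
for name, fn in [("python_loop", python_loop), ("tensor_loop", tensor_loop)]:
    op_types = [op.type for op in fn.get_concrete_function(start).graph.get_operations()]
    # Count Mul ops in the top-level graph and check whether a while-loop op is present
    print(name, op_types.count("Mul"), any(t.endswith("While") for t in op_types))

print(python_loop(start).numpy(), tensor_loop(start).numpy())  # 1024.0 1024.0
```

Both functions compute the same value; they differ only in how the loop is represented in the graph. Unrolling is harmless for a handful of iterations but makes graphs large and slow to trace for long loops.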

Avoid References to Global Variables

Consider the following code:

x = tf.Variable(2, dtype=tf.dtypes.float32)
y = 2

@tf.function
def power(x):
  return tf.pow(x,y)

print(power(x))

y = 3

print(power(x))
tf.Tensor(4.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)

The first time the decorated function power() was called, the output value was the expected 4. The second time, however, the function ignored the change to y. This happens because y is a plain Python value: during tracing, its current value (2) is baked into the computation graph as a constant, and reassigning y afterwards does not trigger a new trace.

A better approach is to store values that change in tf.Variable objects and pass them to the function as arguments:

x = tf.Variable(2, dtype=tf.dtypes.float32)
y = tf.Variable(2, dtype = tf.dtypes.float32)

@tf.function
def power(x,y):
  return tf.pow(x,y)

print(power(x,y))

y.assign(3)

print(power(x,y))
tf.Tensor(4.0, shape=(), dtype=float32)
tf.Tensor(8.0, shape=(), dtype=float32)
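How values are passed also matters for tracing. With default settings, tf.function treats Python numbers passed as arguments as part of the input signature, so every new value produces a fresh trace, whereas tensors are matched only by dtype and shape. The sketch below, with a hypothetical traced_exponents log, shows the difference:

```python
import tensorflow as tf

traced_exponents = []  # hypothetical log, appended to only during tracing


@tf.function
def power(x, y):
    traced_exponents.append(y)  # plain Python: runs only while tracing
    return tf.pow(x, y)


x = tf.constant(2.0)

# Python floats as arguments: every new value produces a new trace
power(x, 2.0)
power(x, 3.0)

# Tensors as arguments: same dtype and shape, so the second call reuses the graph
power(x, tf.constant(2.0))
power(x, tf.constant(3.0))

print(len(traced_exponents))  # 3
```

Passing changing values as tensors or tf.Variable objects therefore gives both correct results and fewer traces.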

Debugging @tf.functions

In general, you want to debug your functions in eager mode and decorate them with @tf.function only after they run correctly, because the error messages in eager mode are more informative.
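If you would rather keep the decorator in place while debugging, tf.config.run_functions_eagerly(True) makes every tf.function run its Python body eagerly, so print() and .numpy() work as usual. A minimal sketch with a hypothetical scaled_sum() function:

```python
import tensorflow as tf


@tf.function
def scaled_sum(v, scale):
    # .numpy() and print() behave normally only when the body runs eagerly
    print("values:", v.numpy())
    return tf.reduce_sum(v) * scale


tf.config.run_functions_eagerly(True)   # run @tf.function bodies eagerly for debugging
result = scaled_sum(tf.constant([1.0, 2.0, 3.0]), 2.0)
tf.config.run_functions_eagerly(False)  # switch back to graph execution

print(result.numpy())  # 12.0
```

Remember to switch graph execution back on afterwards, since eager execution gives up the speed-up of tf.function. The .numpy() call inside scaled_sum() is only valid while the function runs eagerly, so it should be removed before the function is used in graph mode again.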

Some common problems are type errors and shape errors. Type errors happen when there is a mismatch in the type of the variables involved in an operation:

x = tf.Variable(1, dtype = tf.dtypes.float32)
y = tf.Variable(1, dtype = tf.dtypes.int32)

z = tf.add(x,y)
InvalidArgumentError: cannot compute AddV2 as input #1(zero-based) was expected to be a float tensor but is a int32 tensor [Op:AddV2]

Type errors creep in easily, and they are just as easy to fix by casting one of the operands to a different type:

y = tf.cast(y, tf.dtypes.float32)
z = tf.add(x, y) 
tf.print(z) # 2

Shape errors happen when your tensors do not have the shape your operation requires:

x = tf.random.uniform(shape=[100, 100], minval=-1, maxval=1, dtype=tf.dtypes.float32)
y = tf.random.uniform(shape=[1, 100], minval=-1, maxval=1, dtype=tf.dtypes.float32)

z = tf.matmul(x,y)
InvalidArgumentError: Matrix size-incompatible: In[0]: [100,100], In[1]: [1,100] [Op:MatMul]
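The error says that the inner dimensions do not match: tf.matmul() needs the last dimension of x (100) to equal the second-to-last dimension of y (here 1). Assuming y was meant to be used as a column vector (the intended operation is not stated above), one fix is to transpose it:

```python
import tensorflow as tf

x = tf.random.uniform(shape=[100, 100], minval=-1, maxval=1, dtype=tf.dtypes.float32)
y = tf.random.uniform(shape=[1, 100], minval=-1, maxval=1, dtype=tf.dtypes.float32)

# matmul needs the inner dimensions to agree: x is [100, 100], so y must be [100, k].
# Assuming y is meant as a column vector, transpose it inside the matmul:
z = tf.matmul(x, y, transpose_b=True)  # same as tf.matmul(x, tf.transpose(y))
print(z.shape)  # (100, 1)
```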

One convenient tool for fixing both kinds of errors is the interactive Python debugger, which Jupyter Notebook starts automatically whenever an exception is raised once you enable it with the %pdb magic. With it enabled, you can write your function and run it through some common use cases. If there is an error, an interactive prompt opens, which lets you move up and down the call stack and check the values, types, and shapes of your TensorFlow variables.

Conclusion

We've seen how TensorFlow's tf.function() makes your functions more efficient by tracing them into a computation graph, and how the @tf.function decorator applies it to your own functions.

This speed-up is useful in functions that will be called many times, such as custom training steps for machine learning models.
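As an illustration, here is a minimal custom training step under hypothetical assumptions: a toy linear-regression problem (fitting y = 3x + 1), two tf.Variable parameters, and a hand-written gradient-descent update rather than a Keras optimizer. Decorating train_step() with @tf.function means it is traced on the first call and the same graph is reused on every later iteration.

```python
import tensorflow as tf

# Hypothetical toy data: noiseless samples of y = 3x + 1
xs = tf.random.uniform([256])
ys = 3.0 * xs + 1.0

w = tf.Variable(0.0)
b = tf.Variable(0.0)
learning_rate = 0.5  # never changes, so capturing it as a Python value is fine


@tf.function  # traced on the first call; later calls reuse the graph
def train_step(inputs, targets):
    with tf.GradientTape() as tape:
        predictions = w * inputs + b
        loss = tf.reduce_mean(tf.square(targets - predictions))
    dw, db = tape.gradient(loss, [w, b])
    # Plain gradient descent, using assign_sub so the update is part of the graph
    w.assign_sub(learning_rate * dw)
    b.assign_sub(learning_rate * db)
    return loss


for _ in range(300):
    loss = train_step(xs, ys)

print(w.numpy(), b.numpy(), loss.numpy())
```

Because xs and ys keep the same dtype and shape on every call, train_step() is traced only once, and the 300 iterations all run the same graph.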


© 2013-2025 Stack Abuse. All rights reserved.
