Streamlit: How to use cache with TensorFlow

Posted on Sat 29 May 2021 in streamlit

Introduction

A few days ago, I trained an object detection model using the Object Detection API of TensorFlow. The process is pretty straightforward (see this example and this guide):

  1. Label images using labelImg.
  2. Convert training images to TF Records (a binary format rather than JPEG files).
  3. Download a pre-trained model from the Model Zoo.
  4. Configure your pre-trained model by editing a text file.
  5. Train the model.
  6. Export the model.
  7. Test the model.

Now that you have a model trained on your own dataset, you may want to create a demo to show off its potential. To do so, we can use Streamlit. If you are new to that library, I highly recommend checking the Data Professor channel on YouTube, specifically this tutorial.

Below, we explain how to create a Streamlit application with a TensorFlow model. I used the repository created by bourdakos1 to train an object detection model that detects Star Wars spaceships. Run these commands to get a copy:

$ cd ~/git
$ git clone https://github.com/bourdakos1/Custom-Object-Detection.git
$ cd Custom-Object-Detection
$ git checkout 3ea7c53     # we use this specific revision

Take a look at the object_detection_runner.py script located in the object_detection directory. We will use this script as the basis for our Streamlit application:

object_detection/object_detection_runner.py

I copied the content of object_detection_runner.py at the bottom of this post. For now, let's focus on the interesting bits. The following code loads the model into memory:

# model (aka checkpoint)
PATH_TO_CKPT = 'output_inference_graph/frozen_inference_graph.pb'

# Load model into memory
print('Loading model...')
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

We need to load the model before doing any prediction. This is crucial for our Streamlit application: we should not load the model every time a prediction is requested. Instead, we must load it just once and store it in the cache. To do so, we will use the @st.cache decorator as explained here, here and here.
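
For a quick preview of where we are heading (the full version is in the last section; the function name here is just a placeholder), the pattern looks like this:

@st.cache(allow_output_mutation=True)
def load_detection_graph():
    # executed only on the first call; later calls return the cached graph
    detection_graph = tf.Graph()
    # ... read the frozen graph from disk (shown in detail below) ...
    return detection_graph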

The next piece of code shows the detection of objects using test images:

print('detecting...')
with detection_graph.as_default():  # graph loaded in the previous block
    with tf.Session(graph=detection_graph) as sess:

        # get tensors from the detection graph
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
        detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')

        for image_path in TEST_IMAGE_PATHS:
            # read the test image and
            # create an output image with bounding boxes
            detect_objects(image_path)

Intuitively, we need to run the previous block every time we want to detect objects in an image. In other words, we do not need to cache that block. (We could also cache the five tensors, but I did not: storing detection_graph in the cache already gives a huge improvement in RAM usage, while storing the other tensors would bring negligible gains.)

Basic caching

Now that we have highlighted the important sections of object_detection_runner.py, we can talk about caching in Streamlit. From the documentation, we can use the @st.cache decorator to store objects in Streamlit's cache. Consider the following example:

# basic_app.py
# Run
# streamlit run basic_app.py
import streamlit as st
import time

def expensive_computation(a, b):
    time.sleep(10)
    return a*b

a = 2
b = 21

with st.spinner("Running..."):
    res = expensive_computation(a, b)
st.success("Done!")
st.write("Result:", res)

Here, expensive_computation() takes a few seconds before returning a result. We included st.spinner() and st.success() for aesthetic purposes. Once expensive_computation() is done, the rest of the application runs (e.g., printing the result).

This is the result:

If you refresh the application by hitting either R or F5, you will notice that it takes a few seconds to show the result each time. This is because expensive_computation() is being re-executed every time the app runs. To avoid this issue, we can cache that function by using the @st.cache decorator as follows:

@st.cache
def expensive_computation(a, b):
    time.sleep(2)
    return a*b

The rest of the code is unchanged. Now, run the application. The first time it will take a few seconds before printing the result. The second time, it will be faster. Notice that there is an additional message:

You may want to print something from a cached function. However, since the function is cached, such a message will be printed only once. Let's check an example. Modify the function as follows:

@st.cache
def expensive_computation(a, b):
    st.write("Cache miss: expensive_computation(%d, %d) ran" % (a, b))   # add this
    time.sleep(2)
    return a*b

Reload the application. You will see this warning message:

CachedStFunctionWarning: Your script uses st.markdown() or st.write() to write to your Streamlit app from within some cached code at expensive_computation(). This code will only be called when we detect a cache "miss", which can lead to unexpected results.

How to fix this:
    - Move the st.markdown() or st.write() call outside expensive_computation().
    - Or, if you know what you're doing, use @st.cache(suppress_st_warning=True) to suppress the warning.

Now, reload the application one more time. Both the warning (yellow) message and the "Cache miss" message are gone. You will only see the success (green) message and the result. This is what happened:

  • We use st.write() to print "Cache miss" from inside expensive_computation().
  • We cache expensive_computation() using @st.cache (with no additional parameters).
  • The first time expensive_computation() is called, you see "Cache miss" on the screen and Streamlit caches the result.
  • On subsequent reloads, the cached result is returned without running the function, so you no longer see "Cache miss".

In short, we are trying to print something from a cached function. As a result, the "Cache miss" message will be printed only once. Streamlit will warn us about this. We can use the suppress_st_warning option to hide the warning (yellow) message.

@st.cache(suppress_st_warning=True)
def expensive_computation(a, b):
    st.write("Cache miss: expensive_computation(%d, %d) ran" % (a, b))   # add this
    time.sleep(30)
    return a*b

You will only see the success message and the result:

If we change the body of the function or its inputs, Streamlit will update the cache. In fact, Streamlit checks four things when using the @st.cache decorator:

  1. The input parameters that you called the function with
  2. The value of any external variable used in the function
  3. The body of the function
  4. The body of any function used inside the cached function

If this is the first time Streamlit has seen these four components with these exact values and in this exact combination and order, it runs the function and stores the result in a local cache. Then, next time the cached function is called, if none of these components changed, Streamlit will just skip executing the function altogether and, instead, return the output previously stored in the cache.

Documentation

For instance, if you change the input parameters from:

a = 2
b = 21

to:

a = 22
b = 21

Streamlit will show "Running expensive_computation(...)" again. In other words, it needs to re-execute expensive_computation() to update the cached result.
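
The same holds for the second component, the value of any external variable used inside the function. As an illustrative sketch (this variable is not part of the original example), suppose the function reads a module-level constant:

FACTOR = 10   # external variable used inside the cached function

@st.cache
def expensive_computation(a, b):
    time.sleep(2)
    return a * b * FACTOR

# Changing FACTOR (e.g., to 100) and reloading the app changes the cache key,
# so expensive_computation() is re-executed and the new result is cached.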

Advanced caching

In this section, we are going to talk about the allow_output_mutation option of @st.cache. To explain it, we need to cover how Streamlit manages the cache. From the documentation, Streamlit uses a key-value store for the cache. The key is a hash of four elements:

  • The input parameters that you called the function with
  • The value of any external variable used in the function
  • The body of the function
  • The body of any function used inside the cached function

If any of those elements changes, Streamlit generates a new key. Thus, the key of expensive_computation(a=2, b=21) is different from the key of expensive_computation(a=3, b=21).

The value is a tuple that contains:

  • The output of the function itself, that is output = expensive_computation(a, b).
  • The hash of output. This is useful when the output is mutable.

In short, the cache is a dictionary-like store that contains key -> (output, output_hash) items.

Streamlit follows this procedure when looking up a key in the cache (a toy sketch of this logic follows the list):

  1. Computes the cache key.
  2. If the key is found in the cache, then:
    1. Extracts the previously-cached (output, output_hash) tuple.
    2. Performs an Output Mutation Check, where a fresh hash of the output is computed and compared to the stored output_hash.
      1. If the two hashes are different, shows a Cached Object Mutated warning. (Note: Setting allow_output_mutation=True disables this step).
  3. If the key is not found in the cache, then:
    1. Executes the cached function (i.e., output = expensive_computation(2, 21)).
    2. Calculates the output_hash from the function's output.
    3. Stores key -> (output, output_hash) in the cache.
  4. Returns the output.
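
To make this procedure concrete, here is a toy, plain-Python sketch of the key -> (output, output_hash) logic. It is not Streamlit's actual implementation (Streamlit also hashes the function body and external variables); it only mimics the lookup steps above:

import hashlib
import pickle

_cache = {}  # key -> (output, output_hash)

def toy_hash(obj):
    return hashlib.md5(pickle.dumps(obj)).hexdigest()

def cached_expensive_computation(a, b):
    key = toy_hash(("expensive_computation", a, b))
    if key in _cache:
        output, output_hash = _cache[key]
        # Output Mutation Check: re-hash the stored output and compare
        if toy_hash(output) != output_hash:
            print("Cached Object Mutated warning")
        return output
    # cache miss: run the function and store its output together with its hash
    output = expensive_computation(a, b)
    _cache[key] = (output, toy_hash(output))
    return output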

Let's focus on the Output Mutation Check. We used this code in the previous section:

@st.cache(suppress_st_warning=True)
def expensive_computation(a, b):
    st.write("Cache miss: expensive_computation(%d, %d) ran" % (a, b))
    time.sleep(2)
    return a*b

The output of expensive_computation() is an integer; in other words, the output is immutable. Now, let's analyze what happens when we return a mutable output:

@st.cache(suppress_st_warning=True)
def expensive_computation(a, b):
    st.write("Cache miss: expensive_computation(%d, %d) ran" % (a, b))
    time.sleep(2)
    #return a*b             # immutable
    return {"output": a*b}  # mutable

This is the full code:

# basic_app.py
# Run
# streamlit run basic_app.py
import streamlit as st
import time

@st.cache(suppress_st_warning=True)
def expensive_computation(a, b):
    st.write("Cache miss: expensive_computation(%d, %d) ran" % (a, b))   # add this
    time.sleep(2)
    #return a*b             # immutable
    return {"output": a*b}  # mutable

a = 2
b = 21

with st.spinner("Wait..."):
    res = expensive_computation(a, b)
st.success("Done!")
st.write("Result:", res)

and this is the result:

So far so good. Now, let's modify the output of the cached function outside the function:

# mutate the output
res["output"] = "nothing"
st.write("Mutated Result:", res)

This is the full code:

# basic_app.py
# Run
# streamlit run basic_app.py
import streamlit as st
import time

@st.cache(suppress_st_warning=True)
def expensive_computation(a, b):
    st.write("Cache miss: expensive_computation(%d, %d) ran" % (a, b))   # add this
    time.sleep(2)
    #return a*b             # immutable
    return {"output": a*b}  # mutable

a = 2
b = 21

with st.spinner("Wait..."):
    res = expensive_computation(a, b)
st.success("Done!")
st.write("Result:", res)

# mutate the output
res["output"] = "nothing"
st.write("Mutated Result:", res)

You will see this the first time you run the application:

Now, reload the application. You will see a warning before showing output:

Also, note that output has been changed in both messages (Result and Mutated Result):

In short, be careful when accessing the output of cached functions. Remember, Streamlit stores key -> (output, output_hash) pairs in the cache. If you modify output in place, you modify the very object stored in the cache, so its fresh hash no longer matches the stored output_hash.

According to the documentation:

What's going on here is that Streamlit caches the output res by reference. When you mutated res["output"] outside the cached function you ended up inadvertently modifying the cache. This means every subsequent call to expensive_computation(2, 21) will return the wrong value!

We got the correct output the first time we ran the application. We then modified the output of the cached function. Since the output is mutable, we altered the cached output as well.
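
The underlying Python behaviour is plain aliasing: the cache keeps a reference to the dictionary, not a copy of it. A minimal stand-alone illustration (no Streamlit involved):

cache = {}

def cached_result():
    if "key" not in cache:
        cache["key"] = {"output": 42}   # stored by reference, not copied
    return cache["key"]

res = cached_result()
res["output"] = "nothing"   # mutates the very object stored in the cache
print(cached_result())      # {'output': 'nothing'} -- the cache was altered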

What can we do to fix this issue?

  • Do not return mutable objects from the cached function. Instead of returning {"output": a*b}, return a*b. However, this is not always an option; for instance, we may need to return a TensorFlow model.
  • Do not modify a mutable output. In our case, there is no real need to change res["output"]; it is just an example.
  • Clone the output before modifying it (a complete example follows this list). Something like:
from copy import deepcopy
res = deepcopy(expensive_computation(a, b))
  • Allow the cached output to mutate. If you want to allow the cached object to mutate, you can disable the Output Mutation Check by setting allow_output_mutation=True like this:
@st.cache(allow_output_mutation=True)
def expensive_computation(a, b):
    # ...
    return mutable_output
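
For completeness, this is how the third option (cloning the output) would look in basic_app.py; only the lines around the function call change:

from copy import deepcopy

with st.spinner("Wait..."):
    res = deepcopy(expensive_computation(a, b))   # res is now a private copy
st.success("Done!")
st.write("Result:", res)

# mutating the copy no longer touches the cached object
res["output"] = "nothing"
st.write("Mutated Result:", res)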

We will use the last option for caching a TensorFlow model in the next section.

Caching a TensorFlow model

Our object detection application relies on a TensorFlow model (i.e., a graph). This graph is a mutable object, so we will end up with a function that looks like this:

def load_model():
    detection_graph = tf.Graph()
    # read graph from file...
    return detection_graph

Our aim is to create a function that loads the model from disk into memory only once. Otherwise, we would load it every time the application reruns, which would quickly eat up our RAM.

I found this example in the Streamlit Gallery. This is the relevant part of the code:

@st.cache(suppress_st_warning=True)
def set_transform(content):
    try:
        transform = eval(content, {"kornia": kornia, "nn": nn}, None)
    except Exception as e:
        st.write(f"There was an error: {e}")
        transform = nn.Sequential()
    return transform

# load page...

# load transform (e.g., model)
transform = set_transform(content)

We could use that as a basis for loading our model. However, that approach did not work in my case. After digging a bit more, I found this other solution:

from keras import backend as K
from keras.models import load_model as keras_load_model

@st.cache(allow_output_mutation=True)
def load_model():
    # the Keras loader is imported under a different name to avoid colliding
    # with (and recursively calling) this cached function
    model = keras_load_model(MODEL_PATH)
    model._make_predict_function()
    model.summary()  # included to make it visible when model is reloaded
    session = K.get_session()
    return model, session

if __name__ == '__main__':
    st.title('My first app')
    sentence = st.text_input('Input your sentence here:')
    model, session = load_model()
    if sentence:
        K.set_session(session)
        y_hat = model.predict(sentence)

It turns out we need to specify these options in @st.cache for loading a model:

  • suppress_st_warning=True
  • allow_output_mutation=True

This is the function for loading the TensorFlow model:

# Load model into memory
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def set_detection_graph():
    detection_graph = tf.Graph()
    with st.spinner("Loading model for the first time..."):

        with detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
                serialized_graph = fid.read()
                od_graph_def.ParseFromString(serialized_graph)
                tf.import_graph_def(od_graph_def, name='')

    return detection_graph

These are the options:

  • suppress_st_warning=True tells Streamlit that we deliberately call Streamlit functions (here, st.spinner()) inside the cached function, so the CachedStFunctionWarning is not shown.
  • allow_output_mutation=True tells Streamlit that detection_graph may mutate outside of set_detection_graph(), so the Output Mutation Check is skipped (a TensorFlow graph is a mutable object, even though we never modify it ourselves).

Now, we are ready to load the model and detect objects:

if st.button("Detect objects"):

    with st.spinner("Detecting..."):

        # --- detection ---

        # load cached graph
        detection_graph = set_detection_graph()

        with detection_graph.as_default():
            with tf.Session(graph=detection_graph) as sess:

                # load tensors from the graph
                image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
                detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
                detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
                detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
                num_detections = detection_graph.get_tensor_by_name('num_detections:0')

                output_image = detect_objects(input_image)

                st.image(output_image)

        # --- detection ---

    st.success("Done!")