Skip to content

Commit 78f9b97

Browse files
tconklingtvst
authored andcommitted
Warn when calling st.foo from within st.cache (streamlit#185)
* caching.is_within_cached_function() * WIP * Fix tests * More tests * Updated warning text, from tvst * Fix missing whitespace * Update docstrings. * Rename caching-related functions and variables. * Undo some of the renaming in caching.py. user_func -> func
1 parent 404987d commit 78f9b97

File tree

12 files changed

+368
-92
lines changed

12 files changed

+368
-92
lines changed

e2e/scripts/st_in_cache_warning.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2018-2019 Streamlit Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import streamlit as st
17+
18+
19+
@st.cache
20+
def cached_write(value):
21+
st.write(value)
22+
23+
24+
@st.cache(suppress_st_warning=True)
25+
def cached_write_nowarn(value):
26+
st.write(value)
27+
28+
29+
@st.cache
30+
def cached_widget(name):
31+
st.button(name)
32+
33+
34+
cached_write("I'm in a cached function!")
35+
cached_widget("Wadjet!")
36+
cached_write_nowarn("Me too!")

e2e/specs/st_in_cache_warning.spec.ts

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/**
2+
* @license
3+
* Copyright 2018-2019 Streamlit Inc.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
/// <reference types="cypress" />
19+
20+
describe("st calls within cached functions", () => {
21+
beforeEach(() => {
22+
cy.visit("http://localhost:3000/");
23+
});
24+
25+
it("displays expected results", () => {
26+
// We should have two alerts
27+
cy.get(".element-container > .alert-warning").should("have.length", 2);
28+
29+
// One button
30+
cy.get(".element-container > .stButton").should("have.length", 1);
31+
32+
// And two texts
33+
cy.get(".element-container > .markdown-text-container").should(
34+
"have.length",
35+
2
36+
);
37+
});
38+
});

lib/streamlit/DeltaGenerator.py

Lines changed: 41 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616
"""Allows us to create and absorb changes (aka Deltas) to elements."""
1717

1818
# Python 2/3 compatibility
19-
from __future__ import print_function, division, unicode_literals, \
20-
absolute_import
19+
from __future__ import print_function, division, unicode_literals, absolute_import
2120
from streamlit.compatibility import setup_2_3_shims
2221

2322
setup_2_3_shims(globals())
@@ -31,6 +30,7 @@
3130
from datetime import date
3231
from datetime import time
3332

33+
from streamlit import caching
3434
from streamlit import metrics
3535
from streamlit.proto import Balloons_pb2
3636
from streamlit.proto import BlockPath_pb2
@@ -104,6 +104,9 @@ def _with_element(method):
104104

105105
@_wraps_with_cleaned_sig(method, 2) # Remove self and element from sig.
106106
def wrapped_method(dg, *args, **kwargs):
107+
# Warn if we're called from within an @st.cache function
108+
caching.maybe_show_cached_st_function_warning(dg)
109+
107110
delta_type = method.__name__
108111
last_index = -1
109112

@@ -214,22 +217,24 @@ def __init__(self,
214217

215218
def __getattr__(self, name):
216219
import streamlit as st
217-
streamlit_methods = [method_name for method_name in dir(st)
218-
if callable(getattr(st, method_name))]
220+
221+
streamlit_methods = [
222+
method_name for method_name in dir(st) if callable(getattr(st, method_name))
223+
]
219224

220225
def wrapper(*args, **kwargs):
221226
if name in streamlit_methods:
222227
if self._container == BlockPath_pb2.BlockPath.SIDEBAR:
223-
message = "Method `%(name)s()` does not exist for " \
224-
"`st.sidebar`. Did you mean `st.%(name)s()`?" % {
225-
"name": name
226-
}
228+
message = (
229+
"Method `%(name)s()` does not exist for "
230+
"`st.sidebar`. Did you mean `st.%(name)s()`?" % {"name": name}
231+
)
227232
else:
228-
message = "Method `%(name)s()` does not exist for " \
229-
"`DeltaGenerator` objects. Did you mean " \
230-
"`st.%(name)s()`?" % {
231-
"name": name
232-
}
233+
message = (
234+
"Method `%(name)s()` does not exist for "
235+
"`DeltaGenerator` objects. Did you mean "
236+
"`st.%(name)s()`?" % {"name": name}
237+
)
233238
else:
234239
message = "`%(name)s()` is not a valid Streamlit command." % {
235240
"name": name
@@ -682,8 +687,7 @@ def exception(self, element, exception, exception_traceback=None):
682687
"""
683688
import streamlit.elements.exception_proto as exception_proto
684689

685-
exception_proto.marshall(element.exception, exception,
686-
exception_traceback)
690+
exception_proto.marshall(element.exception, exception, exception_traceback)
687691

688692
@_with_element
689693
def _text_exception(self, element, exception_type, message, stack_trace):
@@ -858,8 +862,7 @@ def bar_chart(self, element, data=None, width=0, height=0):
858862
altair.marshall(element.vega_lite_chart, chart, width, height=height)
859863

860864
@_with_element
861-
def vega_lite_chart(self, element, data=None, spec=None, width=0,
862-
**kwargs):
865+
def vega_lite_chart(self, element, data=None, spec=None, width=0, **kwargs):
863866
"""Display a chart using the Vega-Lite library.
864867
865868
Parameters
@@ -914,8 +917,7 @@ def vega_lite_chart(self, element, data=None, spec=None, width=0,
914917
"""
915918
import streamlit.elements.vega_lite as vega_lite
916919

917-
vega_lite.marshall(element.vega_lite_chart, data, spec, width,
918-
**kwargs)
920+
vega_lite.marshall(element.vega_lite_chart, data, spec, width, **kwargs)
919921

920922
@_with_element
921923
def altair_chart(self, element, altair_chart, width=0):
@@ -1031,8 +1033,7 @@ def graphviz_chart(self, element, figure_or_dot, width=0, height=0):
10311033

10321034
@_with_element
10331035
def plotly_chart(
1034-
self, element, figure_or_data, width=0, height=0, sharing="streamlit",
1035-
**kwargs
1036+
self, element, figure_or_data, width=0, height=0, sharing="streamlit", **kwargs
10361037
):
10371038
"""Display an interactive Plotly chart.
10381039
@@ -1111,8 +1112,7 @@ def plotly_chart(
11111112
import streamlit.elements.plotly_chart as plotly_chart
11121113

11131114
plotly_chart.marshall(
1114-
element.plotly_chart, figure_or_data, width, height, sharing,
1115-
**kwargs
1115+
element.plotly_chart, figure_or_data, width, height, sharing, **kwargs
11161116
)
11171117

11181118
@_with_element
@@ -1395,8 +1395,7 @@ def checkbox(self, element, ui_value, label, value=False):
13951395
return current_value
13961396

13971397
@_widget
1398-
def multiselect(self, element, ui_value, label, options,
1399-
format_func=str):
1398+
def multiselect(self, element, ui_value, label, options, format_func=str):
14001399
"""Display a multiselect widget.
14011400
The multiselect widget starts as empty.
14021401
@@ -1430,13 +1429,11 @@ def multiselect(self, element, ui_value, label, options,
14301429

14311430
element.multiselect.label = label
14321431
element.multiselect.default[:] = current_value
1433-
element.multiselect.options[:] = [str(format_func(opt)) for opt in
1434-
options]
1432+
element.multiselect.options[:] = [str(format_func(opt)) for opt in options]
14351433
return [options[i] for i in current_value]
14361434

14371435
@_widget
1438-
def radio(self, element, ui_value, label, options, index=0,
1439-
format_func=str):
1436+
def radio(self, element, ui_value, label, options, index=0, format_func=str):
14401437
"""Display a radio button widget.
14411438
14421439
Parameters
@@ -1470,12 +1467,10 @@ def radio(self, element, ui_value, label, options, index=0,
14701467
14711468
"""
14721469
if not isinstance(index, int):
1473-
raise TypeError(
1474-
"Radio Value has invalid type: %s" % type(index).__name__)
1470+
raise TypeError("Radio Value has invalid type: %s" % type(index).__name__)
14751471

14761472
if len(options) and not 0 <= index < len(options):
1477-
raise ValueError(
1478-
"Radio index must be between 0 and length of options")
1473+
raise ValueError("Radio index must be between 0 and length of options")
14791474

14801475
current_value = ui_value if ui_value is not None else index
14811476

@@ -1485,8 +1480,7 @@ def radio(self, element, ui_value, label, options, index=0,
14851480
return options[current_value] if len(options) else NoValue
14861481

14871482
@_widget
1488-
def selectbox(self, element, ui_value, label, options, index=0,
1489-
format_func=str):
1483+
def selectbox(self, element, ui_value, label, options, index=0, format_func=str):
14901484
"""Display a select widget.
14911485
14921486
Parameters
@@ -1522,15 +1516,13 @@ def selectbox(self, element, ui_value, label, options, index=0,
15221516
)
15231517

15241518
if len(options) and not 0 <= index < len(options):
1525-
raise ValueError(
1526-
"Selectbox index must be between 0 and length of options")
1519+
raise ValueError("Selectbox index must be between 0 and length of options")
15271520

15281521
current_value = ui_value if ui_value is not None else index
15291522

15301523
element.selectbox.label = label
15311524
element.selectbox.value = current_value
1532-
element.selectbox.options[:] = [str(format_func(opt)) for opt in
1533-
options]
1525+
element.selectbox.options[:] = [str(format_func(opt)) for opt in options]
15341526
return options[current_value] if len(options) else NoValue
15351527

15361528
@_widget
@@ -1566,7 +1558,7 @@ def slider(
15661558
Defaults to 1 if the value is an int, 0.01 otherwise.
15671559
format : str or None
15681560
Printf/Python format string.
1569-
1561+
15701562
15711563
Returns
15721564
-------
@@ -1665,8 +1657,7 @@ def slider(
16651657
else:
16661658
start, end = value
16671659
if not min_value <= start <= end <= max_value:
1668-
raise ValueError(
1669-
"The value and/or arguments are out of range.")
1660+
raise ValueError("The value and/or arguments are out of range.")
16701661

16711662
# Convert the current value to the appropriate type.
16721663
current_value = ui_value if ui_value is not None else value
@@ -1688,18 +1679,17 @@ def slider(
16881679
else:
16891680
format = "%0.2f"
16901681
# It would be great if we could guess the number of decimal places from
1691-
# the step`argument, but this would only be meaningful if step were a decimal.
1692-
# As a possible improvement we could make this function accept decimals
1682+
# the step`argument, but this would only be meaningful if step were a decimal.
1683+
# As a possible improvement we could make this function accept decimals
16931684
# and/or use some heuristics for floats.
1694-
1685+
16951686
element.slider.label = label
1696-
element.slider.value[:] = [
1697-
current_value] if single_value else current_value
1687+
element.slider.value[:] = [current_value] if single_value else current_value
16981688
element.slider.min = min_value
16991689
element.slider.max = max_value
17001690
element.slider.step = step
1701-
element.slider.format = format
1702-
1691+
element.slider.format = format
1692+
17031693
return current_value if single_value else tuple(current_value)
17041694

17051695
@_widget
@@ -1795,8 +1785,7 @@ def time_input(self, element, ui_value, label, value=None):
17951785

17961786
# Ensure that the value is either datetime/time
17971787
if not isinstance(value, datetime) and not isinstance(value, time):
1798-
raise TypeError(
1799-
"The type of the value should be either datetime or time.")
1788+
raise TypeError("The type of the value should be either datetime or time.")
18001789

18011790
# Convert datetime to time
18021791
if isinstance(value, datetime):
@@ -1842,8 +1831,7 @@ def date_input(self, element, ui_value, label, value=None):
18421831

18431832
# Ensure that the value is either datetime/time
18441833
if not isinstance(value, datetime) and not isinstance(value, date):
1845-
raise TypeError(
1846-
"The type of the value should be either datetime or date.")
1834+
raise TypeError("The type of the value should be either datetime or date.")
18471835

18481836
# Convert datetime to date
18491837
if isinstance(value, datetime):

0 commit comments

Comments
 (0)