Skip to content

Commit c477701

Browse files
committed
FIx Mobject.replace_shader_code
1 parent d10745a commit c477701

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

manimlib/mobject/mobject.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def __init__(
105105
self.bounding_box: Vect3Array = np.zeros((3, 3))
106106
self._shaders_initialized: bool = False
107107
self._data_has_changed: bool = True
108+
self.shader_code_replacements: dict[str, str] = dict()
108109

109110
self.init_data()
110111
self._data_defaults = np.ones(1, dtype=self.data.dtype)
@@ -1895,12 +1896,12 @@ def deactivate_depth_test(self, recurse: bool = True) -> Self:
18951896

18961897
# Shader code manipulation
18971898

1899+
@affects_data
18981900
def replace_shader_code(self, old: str, new: str) -> Self:
1899-
# TODO, will this work with VMobject structure, given
1900-
# that it does not simpler return shader_wrappers of
1901-
# family?
1902-
for wrapper in self.get_shader_wrapper_list():
1903-
wrapper.replace_code(old, new)
1901+
self.shader_code_replacements[old] = new
1902+
self._shaders_initialized = False
1903+
for mob in self.get_ancestors():
1904+
mob._shaders_initialized = False
19041905
return self
19051906

19061907
def set_color_by_code(self, glsl_code: str) -> Self:
@@ -1969,6 +1970,8 @@ def get_shader_wrapper(self, ctx: Context) -> ShaderWrapper:
19691970
self.shader_wrapper.vert_indices = self.get_shader_vert_indices()
19701971
self.shader_wrapper.bind_to_mobject_uniforms(self.get_uniforms())
19711972
self.shader_wrapper.depth_test = self.depth_test
1973+
for old, new in self.shader_code_replacements.items():
1974+
self.shader_wrapper.replace_code(old, new)
19721975
return self.shader_wrapper
19731976

19741977
def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:

manimlib/mobject/types/vectorized_mobject.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1292,6 +1292,10 @@ def init_shader_data(self, ctx: Context):
12921292
self.fill_shader_wrapper,
12931293
self.stroke_shader_wrapper,
12941294
]
1295+
for sw in self.shader_wrappers:
1296+
rep = self.family_members_with_points()[0]
1297+
for old, new in rep.shader_code_replacements.items():
1298+
sw.replace_code(old, new)
12951299

12961300
def refresh_shader_wrapper_id(self) -> Self:
12971301
if not self._shaders_initialized:
@@ -1355,8 +1359,9 @@ def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:
13551359
self.stroke_shader_wrapper.read_in(stroke_datas),
13561360
]
13571361
for sw in shader_wrappers:
1358-
sw.bind_to_mobject_uniforms(family[0].get_uniforms())
1359-
sw.depth_test = family[0].depth_test
1362+
rep = family[0] # Representative family member
1363+
sw.bind_to_mobject_uniforms(rep.get_uniforms())
1364+
sw.depth_test = rep.depth_test
13601365
return [sw for sw in shader_wrappers if len(sw.vert_data) > 0]
13611366

13621367

0 commit comments

Comments
 (0)