Skip to content

Commit 79e62fc

Browse files
committed
Fix race condition in default engine selection
1 parent 3a90404 commit 79e62fc

File tree

1 file changed

+81
-105
lines changed

1 file changed

+81
-105
lines changed

src/renderer/components/Experiment/Foundation/RunModelButton.tsx

Lines changed: 81 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import InferenceEngineModal from './InferenceEngineModal';
1818
import * as chatAPI from 'renderer/lib/transformerlab-api-sdk';
1919
import OneTimePopup from 'renderer/components/Shared/OneTimePopup';
2020
import { useAPI } from 'renderer/lib/transformerlab-api-sdk';
21-
import React, { useState } from 'react';
21+
import React from 'react';
2222

2323
import { Link } from 'react-router-dom';
2424

@@ -42,6 +42,7 @@ export default function RunModelButton({
4242
const [jobId, setJobId] = useState(null);
4343
const [showRunSettings, setShowRunSettings] = useState(false);
4444
const [pipelineTag, setPipelineTag] = useState<string | null>(null);
45+
const [pipelineTagLoaded, setPipelineTagLoaded] = useState(false);
4546
const [inferenceSettings, setInferenceSettings] = useState({
4647
inferenceEngine: null,
4748
inferenceEngineFriendlyName: '',
@@ -61,52 +62,66 @@ export default function RunModelButton({
6162

6263
const archTag = experimentInfo?.config?.foundation_model_architecture ?? '';
6364

64-
const supportedEngines = React.useMemo(() => {
65-
if (!data) {
66-
return [];
67-
}
65+
// Fetch pipeline tag effect
66+
useEffect(() => {
67+
const fetchPipelineTag = async () => {
68+
if (!experimentInfo?.config?.foundation) {
69+
setPipelineTag(null);
70+
setPipelineTagLoaded(true);
71+
return;
72+
}
6873

69-
const filtered = data.filter((row) => {
70-
// Check if plugin supports the architecture
71-
const supportsArchitecture =
72-
Array.isArray(row.model_architectures) &&
73-
row.model_architectures.some(
74-
(arch) => arch.toLowerCase() === archTag.toLowerCase(),
75-
);
74+
setPipelineTagLoaded(false);
75+
try {
76+
const url = getAPIFullPath('models', ['pipeline_tag'], {
77+
modelName: experimentInfo.config.foundation,
78+
});
79+
const response = await fetch(url, { method: 'GET' });
80+
if (!response.ok) {
81+
setPipelineTag(null);
82+
} else {
83+
const data = await response.json();
84+
setPipelineTag(data?.data || null);
85+
}
86+
} catch (e) {
87+
setPipelineTag(null);
88+
console.error('Error fetching pipeline tag:', e);
89+
} finally {
90+
setPipelineTagLoaded(true);
91+
}
92+
};
7693

77-
// Check if plugin has text-to-speech support
78-
const hasTextToSpeechSupport =
79-
Array.isArray(row.supports) &&
80-
row.supports.some(
81-
(support) => support.toLowerCase() === 'text-to-speech',
82-
);
94+
fetchPipelineTag();
95+
}, [experimentInfo?.config?.foundation]);
8396

84-
// Apply filtering logic based on pipeline tag
97+
const supportedEngines = React.useMemo(() => {
98+
if (!data || !pipelineTagLoaded) return [];
99+
100+
return data.filter((row) => {
101+
const supportsArchitecture = Array.isArray(row.model_architectures) &&
102+
row.model_architectures.some(arch => arch.toLowerCase() === archTag.toLowerCase());
103+
104+
const hasTextToSpeechSupport = Array.isArray(row.supports) &&
105+
row.supports.some(support => support.toLowerCase() === 'text-to-speech');
106+
107+
if (!supportsArchitecture) return false;
108+
109+
// For text-to-speech models: must also have text-to-speech support
85110
if (pipelineTag === 'text-to-speech') {
86-
// For text-to-speech models: must support architecture AND text-to-speech
87-
return supportsArchitecture && hasTextToSpeechSupport;
88-
} else {
89-
// For non-text-to-speech models: must support architecture but NOT text-to-speech
90-
return supportsArchitecture && !hasTextToSpeechSupport;
111+
return hasTextToSpeechSupport;
91112
}
113+
114+
// For non-text-to-speech models: must NOT have text-to-speech support
115+
return !hasTextToSpeechSupport;
92116
});
93-
94-
return filtered;
95-
}, [data, archTag, pipelineTag]);
117+
}, [data, archTag, pipelineTag, pipelineTagLoaded]);
96118

97119
const unsupportedEngines = React.useMemo(() => {
98-
if (!data) {
99-
return [];
100-
}
101-
const filtered = data.filter(
102-
(row) =>
103-
!Array.isArray(row.model_architectures) ||
104-
!row.model_architectures.some(
105-
(arch) => arch.toLowerCase() === archTag.toLowerCase(),
106-
),
107-
);
108-
return filtered;
109-
}, [data, archTag]);
120+
if (!data) return [];
121+
122+
// Simply return everything that's NOT in supportedEngines
123+
return data.filter(row => !supportedEngines.some(supported => supported.uniqueId === row.uniqueId));
124+
}, [data, supportedEngines]);
110125

111126
const [isValidDiffusionModel, setIsValidDiffusionModel] = useState<
112127
boolean | null
@@ -152,54 +167,42 @@ export default function RunModelButton({
152167
};
153168
}
154169

155-
// Set a default inference Engine if there is none
156170
useEffect(() => {
171+
if (!data || !pipelineTagLoaded || supportedEngines.length === 0) return;
172+
157173
let objExperimentInfo = null;
158174
if (experimentInfo?.config?.inferenceParams) {
159-
objExperimentInfo = JSON.parse(experimentInfo?.config?.inferenceParams);
175+
try {
176+
objExperimentInfo = JSON.parse(experimentInfo?.config?.inferenceParams);
177+
} catch (e) {
178+
console.error('Failed to parse inferenceParams:', e);
179+
}
160180
}
161-
if (
162-
objExperimentInfo == null ||
163-
objExperimentInfo?.inferenceEngine == null
164-
) {
165-
// If there are supportedEngines, set the first one from supported engines as default
166-
if (supportedEngines.length > 0) {
167-
const firstEngine = supportedEngines[0];
168-
const newInferenceSettings = {
169-
inferenceEngine: firstEngine.uniqueId || null,
170-
inferenceEngineFriendlyName: firstEngine.name || '',
171-
};
172-
setInferenceSettings(newInferenceSettings);
173181

174-
// Update the experiment config with the first supported engine
175-
if (experimentInfo?.id) {
176-
fetch(
177-
chatAPI.Endpoints.Experiment.UpdateConfig(
178-
experimentInfo.id,
179-
'inferenceParams',
180-
JSON.stringify(newInferenceSettings),
181-
),
182-
).catch(() => {
183-
console.error(
184-
'Failed to update inferenceParams in experiment config',
185-
);
186-
});
187-
}
188-
} else {
189-
// This preserves the older logic where we try to get the default inference engine for a blank experiment
190-
(async () => {
191-
const { inferenceEngine, inferenceEngineFriendlyName } =
192-
await getDefaultinferenceEngines();
193-
setInferenceSettings({
194-
inferenceEngine: inferenceEngine || null,
195-
inferenceEngineFriendlyName: inferenceEngineFriendlyName || null,
196-
});
197-
})();
182+
const currentEngine = objExperimentInfo?.inferenceEngine;
183+
const currentEngineIsSupported = supportedEngines.some(engine => engine.uniqueId === currentEngine);
184+
185+
if (!currentEngine || !currentEngineIsSupported) {
186+
const firstEngine = supportedEngines[0];
187+
const newSettings = {
188+
inferenceEngine: firstEngine.uniqueId,
189+
inferenceEngineFriendlyName: firstEngine.name,
190+
};
191+
setInferenceSettings(newSettings);
192+
193+
if (experimentInfo?.id) {
194+
fetch(
195+
chatAPI.Endpoints.Experiment.UpdateConfig(
196+
experimentInfo.id,
197+
'inferenceParams',
198+
JSON.stringify(newSettings),
199+
),
200+
).catch(console.error);
198201
}
199202
} else {
200203
setInferenceSettings(objExperimentInfo);
201204
}
202-
}, [experimentInfo, supportedEngines, pipelineTag]);
205+
}, [data, pipelineTagLoaded, supportedEngines, experimentInfo?.id, experimentInfo?.config?.inferenceParams]);
203206

204207
// Check if the current foundation model is a diffusion model
205208
useEffect(() => {
@@ -228,33 +231,6 @@ export default function RunModelButton({
228231
checkValidDiffusion();
229232
}, [experimentInfo?.config?.foundation]);
230233

231-
useEffect(() => {
232-
const fetchPipelineTag = async () => {
233-
if (!experimentInfo?.config?.foundation) {
234-
setPipelineTag(null);
235-
return;
236-
}
237-
238-
try {
239-
const url = getAPIFullPath('models', ['pipeline_tag'], {
240-
modelName: experimentInfo.config.foundation,
241-
});
242-
const response = await fetch(url, { method: 'GET' });
243-
if (!response.ok) {
244-
setPipelineTag(null);
245-
} else {
246-
const data = await response.json();
247-
setPipelineTag(data?.data || null);
248-
}
249-
} catch (e) {
250-
setPipelineTag(null);
251-
console.error('Error fetching pipeline tag:', e);
252-
}
253-
};
254-
255-
fetchPipelineTag();
256-
}, [experimentInfo?.config?.foundation]);
257-
258234
function Engine() {
259235
return (
260236
<>

0 commit comments

Comments
 (0)