@@ -18,7 +18,7 @@ import InferenceEngineModal from './InferenceEngineModal';
 import * as chatAPI from 'renderer/lib/transformerlab-api-sdk';
 import OneTimePopup from 'renderer/components/Shared/OneTimePopup';
 import { useAPI } from 'renderer/lib/transformerlab-api-sdk';
-import React, { useState } from 'react';
+import React from 'react';
 
 import { Link } from 'react-router-dom';
 
@@ -42,6 +42,7 @@ export default function RunModelButton({
   const [jobId, setJobId] = useState(null);
   const [showRunSettings, setShowRunSettings] = useState(false);
   const [pipelineTag, setPipelineTag] = useState<string | null>(null);
+  const [pipelineTagLoaded, setPipelineTagLoaded] = useState(false);
   const [inferenceSettings, setInferenceSettings] = useState({
     inferenceEngine: null,
     inferenceEngineFriendlyName: '',
@@ -61,52 +62,66 @@ export default function RunModelButton({
 
   const archTag = experimentInfo?.config?.foundation_model_architecture ?? '';
 
-  const supportedEngines = React.useMemo(() => {
-    if (!data) {
-      return [];
-    }
+  // Fetch pipeline tag effect
+  useEffect(() => {
+    const fetchPipelineTag = async () => {
+      if (!experimentInfo?.config?.foundation) {
+        setPipelineTag(null);
+        setPipelineTagLoaded(true);
+        return;
+      }
 
-    const filtered = data.filter((row) => {
-      // Check if plugin supports the architecture
-      const supportsArchitecture =
-        Array.isArray(row.model_architectures) &&
-        row.model_architectures.some(
-          (arch) => arch.toLowerCase() === archTag.toLowerCase(),
-        );
+      setPipelineTagLoaded(false);
+      try {
+        const url = getAPIFullPath('models', ['pipeline_tag'], {
+          modelName: experimentInfo.config.foundation,
+        });
+        const response = await fetch(url, { method: 'GET' });
+        if (!response.ok) {
+          setPipelineTag(null);
+        } else {
+          const data = await response.json();
+          setPipelineTag(data?.data || null);
+        }
+      } catch (e) {
+        setPipelineTag(null);
+        console.error('Error fetching pipeline tag:', e);
+      } finally {
+        setPipelineTagLoaded(true);
+      }
+    };
 
-      // Check if plugin has text-to-speech support
-      const hasTextToSpeechSupport =
-        Array.isArray(row.supports) &&
-        row.supports.some(
-          (support) => support.toLowerCase() === 'text-to-speech',
-        );
+    fetchPipelineTag();
+  }, [experimentInfo?.config?.foundation]);
 
-      // Apply filtering logic based on pipeline tag
+  const supportedEngines = React.useMemo(() => {
+    if (!data || !pipelineTagLoaded) return [];
+
+    return data.filter((row) => {
+      const supportsArchitecture = Array.isArray(row.model_architectures) &&
+        row.model_architectures.some(arch => arch.toLowerCase() === archTag.toLowerCase());
+
+      const hasTextToSpeechSupport = Array.isArray(row.supports) &&
+        row.supports.some(support => support.toLowerCase() === 'text-to-speech');
+
+      if (!supportsArchitecture) return false;
+
+      // For text-to-speech models: must also have text-to-speech support
       if (pipelineTag === 'text-to-speech') {
-        // For text-to-speech models: must support architecture AND text-to-speech
-        return supportsArchitecture && hasTextToSpeechSupport;
-      } else {
-        // For non-text-to-speech models: must support architecture but NOT text-to-speech
-        return supportsArchitecture && !hasTextToSpeechSupport;
+        return hasTextToSpeechSupport;
       }
+
+      // For non-text-to-speech models: must NOT have text-to-speech support
+      return !hasTextToSpeechSupport;
     });
-
-    return filtered;
-  }, [data, archTag, pipelineTag]);
+  }, [data, archTag, pipelineTag, pipelineTagLoaded]);
 
   const unsupportedEngines = React.useMemo(() => {
-    if (!data) {
-      return [];
-    }
-    const filtered = data.filter(
-      (row) =>
-        !Array.isArray(row.model_architectures) ||
-        !row.model_architectures.some(
-          (arch) => arch.toLowerCase() === archTag.toLowerCase(),
-        ),
-    );
-    return filtered;
-  }, [data, archTag]);
+    if (!data) return [];
+
+    // Simply return everything that's NOT in supportedEngines
+    return data.filter(row => !supportedEngines.some(supported => supported.uniqueId === row.uniqueId));
+  }, [data, supportedEngines]);
 
   const [isValidDiffusionModel, setIsValidDiffusionModel] = useState<
     boolean | null
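
The useMemo pair above boils down to one predicate per plugin row: the plugin must list the model's architecture, and its text-to-speech support has to agree with whether the model's pipeline tag is 'text-to-speech'. A minimal sketch of that rule as a standalone function, assuming a simplified EngineRow shape and the name isSupportedEngine purely for illustration (neither is taken from the codebase):

// Illustrative row shape; the real plugin type in the project may differ.
interface EngineRow {
  uniqueId: string;
  name: string;
  model_architectures?: string[];
  supports?: string[];
}

// True when the plugin belongs in supportedEngines for this model.
function isSupportedEngine(
  row: EngineRow,
  archTag: string,
  pipelineTag: string | null,
): boolean {
  const supportsArchitecture =
    Array.isArray(row.model_architectures) &&
    row.model_architectures.some(
      (arch) => arch.toLowerCase() === archTag.toLowerCase(),
    );
  if (!supportsArchitecture) return false;

  const hasTextToSpeech =
    Array.isArray(row.supports) &&
    row.supports.some((s) => s.toLowerCase() === 'text-to-speech');

  // TTS models need a TTS-capable engine; all other models must avoid them.
  return pipelineTag === 'text-to-speech' ? hasTextToSpeech : !hasTextToSpeech;
}

With a predicate like this, unsupportedEngines falls out as everything in data whose uniqueId is not in the supported list, which is why its memo now depends on supportedEngines rather than repeating the architecture check.
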
@@ -152,54 +167,42 @@ export default function RunModelButton({
     };
   }
 
-  // Set a default inference Engine if there is none
   useEffect(() => {
+    if (!data || !pipelineTagLoaded || supportedEngines.length === 0) return;
+
     let objExperimentInfo = null;
     if (experimentInfo?.config?.inferenceParams) {
-      objExperimentInfo = JSON.parse(experimentInfo?.config?.inferenceParams);
+      try {
+        objExperimentInfo = JSON.parse(experimentInfo?.config?.inferenceParams);
+      } catch (e) {
+        console.error('Failed to parse inferenceParams:', e);
+      }
     }
-    if (
-      objExperimentInfo == null ||
-      objExperimentInfo?.inferenceEngine == null
-    ) {
-      // If there are supportedEngines, set the first one from supported engines as default
-      if (supportedEngines.length > 0) {
-        const firstEngine = supportedEngines[0];
-        const newInferenceSettings = {
-          inferenceEngine: firstEngine.uniqueId || null,
-          inferenceEngineFriendlyName: firstEngine.name || '',
-        };
-        setInferenceSettings(newInferenceSettings);
 
-        // Update the experiment config with the first supported engine
-        if (experimentInfo?.id) {
-          fetch(
-            chatAPI.Endpoints.Experiment.UpdateConfig(
-              experimentInfo.id,
-              'inferenceParams',
-              JSON.stringify(newInferenceSettings),
-            ),
-          ).catch(() => {
-            console.error(
-              'Failed to update inferenceParams in experiment config',
-            );
-          });
-        }
-      } else {
-        // This preserves the older logic where we try to get the default inference engine for a blank experiment
-        (async () => {
-          const { inferenceEngine, inferenceEngineFriendlyName } =
-            await getDefaultinferenceEngines();
-          setInferenceSettings({
-            inferenceEngine: inferenceEngine || null,
-            inferenceEngineFriendlyName: inferenceEngineFriendlyName || null,
-          });
-        })();
+    const currentEngine = objExperimentInfo?.inferenceEngine;
+    const currentEngineIsSupported = supportedEngines.some(engine => engine.uniqueId === currentEngine);
+
+    if (!currentEngine || !currentEngineIsSupported) {
+      const firstEngine = supportedEngines[0];
+      const newSettings = {
+        inferenceEngine: firstEngine.uniqueId,
+        inferenceEngineFriendlyName: firstEngine.name,
+      };
+      setInferenceSettings(newSettings);
+
+      if (experimentInfo?.id) {
+        fetch(
+          chatAPI.Endpoints.Experiment.UpdateConfig(
+            experimentInfo.id,
+            'inferenceParams',
+            JSON.stringify(newSettings),
+          ),
+        ).catch(console.error);
       }
     } else {
       setInferenceSettings(objExperimentInfo);
     }
-  }, [experimentInfo, supportedEngines, pipelineTag]);
+  }, [data, pipelineTagLoaded, supportedEngines, experimentInfo?.id, experimentInfo?.config?.inferenceParams]);
 
   // Check if the current foundation model is a diffusion model
   useEffect(() => {
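
The rewritten effect above no longer just fills in a missing engine; it also replaces a saved engine that is no longer present in supportedEngines (for example after the pipeline tag changes which engines qualify). Roughly, its branch logic reduces to a helper like the following, where resolveInferenceSettings and the InferenceSettings shape are illustrative names, not part of the component:

interface InferenceSettings {
  inferenceEngine: string | null;
  inferenceEngineFriendlyName: string;
}

// Keep the saved settings if the engine is still supported,
// otherwise fall back to the first supported engine.
function resolveInferenceSettings(
  saved: InferenceSettings | null,
  supportedEngines: { uniqueId: string; name: string }[],
): InferenceSettings | null {
  if (supportedEngines.length === 0) return saved; // the effect bails out early in this case

  const savedEngine = saved?.inferenceEngine ?? null;
  const stillSupported =
    savedEngine !== null &&
    supportedEngines.some((e) => e.uniqueId === savedEngine);

  if (stillSupported) return saved;

  const first = supportedEngines[0];
  return {
    inferenceEngine: first.uniqueId,
    inferenceEngineFriendlyName: first.name,
  };
}

In the component, the "fall back" branch is also what triggers the UpdateConfig call so the chosen engine is persisted back to the experiment.
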
@@ -228,33 +231,6 @@ export default function RunModelButton({
     checkValidDiffusion();
   }, [experimentInfo?.config?.foundation]);
 
-  useEffect(() => {
-    const fetchPipelineTag = async () => {
-      if (!experimentInfo?.config?.foundation) {
-        setPipelineTag(null);
-        return;
-      }
-
-      try {
-        const url = getAPIFullPath('models', ['pipeline_tag'], {
-          modelName: experimentInfo.config.foundation,
-        });
-        const response = await fetch(url, { method: 'GET' });
-        if (!response.ok) {
-          setPipelineTag(null);
-        } else {
-          const data = await response.json();
-          setPipelineTag(data?.data || null);
-        }
-      } catch (e) {
-        setPipelineTag(null);
-        console.error('Error fetching pipeline tag:', e);
-      }
-    };
-
-    fetchPipelineTag();
-  }, [experimentInfo?.config?.foundation]);
-
   function Engine() {
     return (
       <>
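
This last hunk removes the original pipeline-tag effect; its replacement earlier in the diff adds a finally block and the pipelineTagLoaded flag so the engine lists are only computed once the tag lookup has settled. A rough standalone sketch of that fetch pattern (loadPipelineTag and ApiPathBuilder are made-up names, and getAPIFullPath is passed in rather than imported so the sketch stays self-contained; the endpoint arguments and response shape are copied from the diff, not independently verified):

type ApiPathBuilder = (
  group: string,
  parts: string[],
  params: Record<string, string>,
) => string;

// Resolve the model's pipeline tag, or null if it is unknown or the request fails.
async function loadPipelineTag(
  foundation: string | null,
  getAPIFullPath: ApiPathBuilder,
): Promise<string | null> {
  if (!foundation) return null;
  try {
    const url = getAPIFullPath('models', ['pipeline_tag'], { modelName: foundation });
    const response = await fetch(url, { method: 'GET' });
    if (!response.ok) return null;
    const body = await response.json();
    return body?.data || null;
  } catch (e) {
    console.error('Error fetching pipeline tag:', e);
    return null;
  }
}

A caller would flip pipelineTagLoaded to true once this resolves, which is the condition the supportedEngines memo waits on before filtering.
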