Skip to content

Commit f1fc0cb

Browse files
authored
Merge pull request transformerlab#399 from transformerlab/add/supports-filtering-interact
Add filtering based on the supports field for interact in loader plugins
2 parents 0b75a1b + 5f7be1f commit f1fc0cb

File tree

2 files changed

+104
-11
lines changed

2 files changed

+104
-11
lines changed

src/renderer/components/Experiment/Interact/Interact.tsx

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,25 @@ import VisualizeGeneration from './VisualizeGeneration';
4040
import ModelLayerVisualization from './ModelLayerVisualization';
4141

4242
const fetcher = (url) => fetch(url).then((res) => res.json());
43+
// const supports = [
44+
// 'chat',
45+
// 'completion',
46+
// 'rag',
47+
// 'tools',
48+
// 'template',
49+
// 'embeddings',
50+
// 'tokenize',
51+
// 'logprobs',
52+
// 'batched',
53+
// ];
4354

4455
export default function Chat({
4556
experimentInfo,
4657
experimentInfoMutate,
4758
setRagEngine,
4859
mode,
4960
setMode,
61+
supports,
5062
}) {
5163
const { models } = chatAPI.useModelStatus();
5264
const [conversationId, setConversationId] = React.useState(null);
@@ -764,17 +776,51 @@ export default function Chat({
764776
}
765777
}
766778
>
767-
<Option value="chat">Chat</Option>
768-
<Option value="completion">Completion</Option>
769-
<Option value="visualize_model">Model Activations</Option>
770-
<Option value="model_layers">Model Architecture</Option>
771-
<Option value="rag">Query Docs (RAG)</Option>
772-
<Option value="tools">Tool Calling</Option>
773-
<Option value="template">Templated Prompt</Option>
774-
<Option value="embeddings">Embeddings</Option>
775-
<Option value="tokenize">Tokenize</Option>
776-
<Option value="logprobs">Visualize Logprobs</Option>
777-
<Option value="batched">Batched Query</Option>
779+
<Option value="chat" disabled={!supports.includes('chat')}>
780+
Chat
781+
</Option>
782+
<Option
783+
value="completion"
784+
disabled={!supports.includes('completion')}
785+
>
786+
Completion
787+
</Option>
788+
<Option
789+
value="visualize_model"
790+
disabled={!supports.includes('visualize_model')}
791+
>
792+
Model Activations
793+
</Option>
794+
<Option
795+
value="model_layers"
796+
disabled={!supports.includes('model_layers')}
797+
>
798+
Model Architecture
799+
</Option>
800+
<Option value="rag" disabled={!supports.includes('rag')}>
801+
Query Docs (RAG)
802+
</Option>
803+
<Option value="tools" disabled={!supports.includes('tools')}>
804+
Tool Calling
805+
</Option>
806+
<Option value="template" disabled={!supports.includes('template')}>
807+
Templated Prompt
808+
</Option>
809+
<Option
810+
value="embeddings"
811+
disabled={!supports.includes('embeddings')}
812+
>
813+
Embeddings
814+
</Option>
815+
<Option value="tokenize" disabled={!supports.includes('tokenize')}>
816+
Tokenize
817+
</Option>
818+
<Option value="logprobs" disabled={!supports.includes('logprobs')}>
819+
Visualize Logprobs
820+
</Option>
821+
<Option value="batched" disabled={!supports.includes('batched')}>
822+
Batched Query
823+
</Option>
778824
</Select>
779825
</FormControl>
780826
<Typography level="title-md">

src/renderer/components/MainAppPanel.tsx

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ import {
1616
useEffect,
1717
useState,
1818
} from 'react';
19+
import useSWR from 'swr';
20+
1921
import { AnalyticsBrowser } from '@segment/analytics-next';
2022
import Data from './Data/Data';
2123
import Interact from './Experiment/Interact/Interact';
@@ -117,6 +119,49 @@ export default function MainAppPanel({
117119
const [selectedInteractSubpage, setSelectedInteractSubpage] =
118120
useState('chat');
119121

122+
const fetcher = (url) => fetch(url).then((res) => res.json());
123+
124+
// Extract pluginId at the top level
125+
const inferenceParams = experimentInfo?.config?.inferenceParams;
126+
const pluginId = inferenceParams
127+
? JSON.parse(inferenceParams)?.inferenceEngine
128+
: null;
129+
130+
// Use SWR at the top level, not inside useEffect
131+
const { data: modelData } = useSWR(
132+
experimentInfo?.id && pluginId
133+
? chatAPI.Endpoints.Experiment.ScriptGetFile(
134+
experimentInfo.id,
135+
pluginId,
136+
'index.json',
137+
)
138+
: null,
139+
fetcher,
140+
);
141+
142+
let modelSupports = [
143+
'chat',
144+
'completion',
145+
'rag',
146+
'tools',
147+
'template',
148+
'embeddings',
149+
'tokenize',
150+
'batched',
151+
];
152+
153+
if (modelData && modelData !== 'null' && modelData !== 'undefined') {
154+
modelSupports = JSON.parse(modelData)?.supports || [
155+
'chat',
156+
'completion',
157+
'rag',
158+
'tools',
159+
'template',
160+
'embeddings',
161+
'tokenize',
162+
'batched',
163+
];
164+
}
120165
const setFoundation = useCallback(
121166
(model) => {
122167
let model_name = '';
@@ -351,6 +396,7 @@ export default function MainAppPanel({
351396
setRagEngine={setRagEngine}
352397
mode={selectedInteractSubpage}
353398
setMode={setSelectedInteractSubpage}
399+
supports={modelSupports}
354400
/>
355401
}
356402
/>
@@ -363,6 +409,7 @@ export default function MainAppPanel({
363409
setRagEngine={setRagEngine}
364410
mode={'model_layers'}
365411
setMode={setSelectedInteractSubpage}
412+
supports={modelSupports}
366413
/>
367414
}
368415
/>

0 commit comments

Comments
 (0)