Skip to content

Commit 876df7b

Browse files
authored
Merge pull request transformerlab#366 from transformerlab/add/openai-gpt4.1
Add support for GPT 4.1 models
2 parents 7d810e5 + 5d67448 commit 876df7b

File tree

1 file changed

+143
-133
lines changed

1 file changed

+143
-133
lines changed

src/renderer/components/Experiment/Widgets/ModelProviderWidget.tsx

Lines changed: 143 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -3,156 +3,166 @@ import useSWR from 'swr';
33
import * as chatAPI from 'renderer/lib/transformerlab-api-sdk';
44
import { Autocomplete } from '@mui/joy';
55
import {
6-
WidgetProps,
7-
RJSFSchema,
8-
StrictRJSFSchema,
9-
FormContextType,
6+
WidgetProps,
7+
RJSFSchema,
8+
StrictRJSFSchema,
9+
FormContextType,
1010
} from '@rjsf/utils';
1111

1212
// Simple fetcher for useSWR.
1313
const fetcher = (url: string) => fetch(url).then((res) => res.json());
1414

1515
function ModelProviderWidget<
16-
T = any,
17-
S extends StrictRJSFSchema = RJSFSchema,
18-
F extends FormContextType = any
16+
T = any,
17+
S extends StrictRJSFSchema = RJSFSchema,
18+
F extends FormContextType = any,
1919
>(props: WidgetProps<T, S, F>) {
20-
const {
21-
id,
22-
value,
23-
required,
24-
disabled,
25-
readonly,
26-
autofocus,
27-
onChange,
28-
options,
29-
schema,
30-
multiple,
31-
} = props;
20+
const {
21+
id,
22+
value,
23+
required,
24+
disabled,
25+
readonly,
26+
autofocus,
27+
onChange,
28+
options,
29+
schema,
30+
multiple,
31+
} = props;
3232

33-
// Determine multiple, defaulting to true.
34-
const _multiple =
35-
typeof multiple !== 'undefined'
36-
? Boolean(multiple)
37-
: typeof options.multiple !== 'undefined'
38-
? Boolean(options.multiple)
39-
: true;
33+
// Determine multiple, defaulting to true.
34+
const _multiple =
35+
typeof multiple !== 'undefined'
36+
? Boolean(multiple)
37+
: typeof options.multiple !== 'undefined'
38+
? Boolean(options.multiple)
39+
: true;
4040

41-
// Disabled API key mapping.
42-
const isDisabledFilter = true;
43-
const disabledEnvMap = {
44-
claude: 'ANTHROPIC_API_KEY',
45-
azure: 'AZURE_OPENAI_DETAILS',
46-
openai: 'OPENAI_API_KEY',
47-
custom: 'CUSTOM_MODEL_API_KEY',
48-
};
49-
const configKeysInOrder = Object.values(disabledEnvMap);
50-
const configResults = configKeysInOrder.map((key) =>
51-
useSWR(chatAPI.Endpoints.Config.Get(key), fetcher)
52-
);
53-
const configValues = React.useMemo(() => {
54-
const values: Record<string, any> = {};
55-
configKeysInOrder.forEach((key, idx) => {
56-
values[key] = configResults[idx]?.data;
57-
});
58-
return values;
59-
}, [configKeysInOrder, configResults]);
41+
// Disabled API key mapping.
42+
const isDisabledFilter = true;
43+
const disabledEnvMap = {
44+
claude: 'ANTHROPIC_API_KEY',
45+
azure: 'AZURE_OPENAI_DETAILS',
46+
openai: 'OPENAI_API_KEY',
47+
custom: 'CUSTOM_MODEL_API_KEY',
48+
};
49+
const configKeysInOrder = Object.values(disabledEnvMap);
50+
const configResults = configKeysInOrder.map((key) =>
51+
useSWR(chatAPI.Endpoints.Config.Get(key), fetcher),
52+
);
53+
const configValues = React.useMemo(() => {
54+
const values: Record<string, any> = {};
55+
configKeysInOrder.forEach((key, idx) => {
56+
values[key] = configResults[idx]?.data;
57+
});
58+
return values;
59+
}, [configKeysInOrder, configResults]);
6060

61-
// Map: label => stored value.
62-
const labelToCustomValue: Record<string, string> = {
63-
'Claude 3.7 Sonnet': 'claude-3-7-sonnet-latest',
64-
'Claude 3.5 Haiku': 'claude-3-5-haiku-latest',
65-
'OpenAI GPT 4o': 'gpt-4o',
66-
'OpenAI GPT 4o Mini': 'gpt-4o-mini',
67-
'Azure OpenAI': 'azure-openai',
68-
'Custom Model API': 'custom-model-api',
69-
'Local': 'local',
70-
};
61+
// Map: label => stored value.
62+
const labelToCustomValue: Record<string, string> = {
63+
'Claude 3.7 Sonnet': 'claude-3-7-sonnet-latest',
64+
'Claude 3.5 Haiku': 'claude-3-5-haiku-latest',
65+
'OpenAI GPT 4o': 'gpt-4o',
66+
'OpenAI GPT 4.1': 'gpt-4.1',
67+
'OpenAI GPT 4o Mini': 'gpt-4o-mini',
68+
'OpenAI GPT 4.1 Mini': 'gpt-4.1-mini',
69+
'OpenAI GPT 4.1 Nano': 'gpt-4.1-nano',
70+
'Azure OpenAI': 'azure-openai',
71+
'Custom Model API': 'custom-model-api',
72+
Local: 'local',
73+
};
7174

72-
// Options coming from mapping keys.
73-
const optionsList = Object.keys(labelToCustomValue);
75+
// Options coming from mapping keys.
76+
const optionsList = Object.keys(labelToCustomValue);
7477

75-
// Inverse mapping: stored value => label.
76-
const customValueToLabel = Object.entries(labelToCustomValue).reduce(
77-
(acc, [label, custom]) => {
78-
acc[custom] = label;
79-
return acc;
80-
},
81-
{} as Record<string, string>
82-
);
78+
// Inverse mapping: stored value => label.
79+
const customValueToLabel = Object.entries(labelToCustomValue).reduce(
80+
(acc, [label, custom]) => {
81+
acc[custom] = label;
82+
return acc;
83+
},
84+
{} as Record<string, string>,
85+
);
8386

84-
// Set default/current value.
85-
const defaultValue = _multiple ? [] : '';
86-
const currentValue = value !== undefined ? value : defaultValue;
87+
// Set default/current value.
88+
const defaultValue = _multiple ? [] : '';
89+
const currentValue = value !== undefined ? value : defaultValue;
8790

88-
// Convert stored value(s) to display labels.
89-
const displayValue = _multiple
90-
? Array.isArray(currentValue)
91-
? currentValue.map((val) => customValueToLabel[val] || val)
92-
: []
93-
: customValueToLabel[currentValue] || currentValue;
91+
// Convert stored value(s) to display labels.
92+
const displayValue = _multiple
93+
? Array.isArray(currentValue)
94+
? currentValue.map((val) => customValueToLabel[val] || val)
95+
: []
96+
: customValueToLabel[currentValue] || currentValue;
9497

95-
// Build disabled mapping for options.
96-
const combinedOptions = optionsList.reduce(
97-
(acc: Record<string, { disabled: boolean; info?: string }>, opt) => {
98-
const lower = opt.toLowerCase();
99-
let optDisabled = false;
100-
let infoMessage = '';
101-
if (isDisabledFilter) {
102-
for (const envKey in disabledEnvMap) {
103-
if (lower.startsWith(envKey)) {
104-
const configKey = disabledEnvMap[envKey];
105-
const configVal = configValues[configKey];
106-
optDisabled = !configVal || configVal === '';
107-
if (optDisabled) {
108-
infoMessage = `Please set ${configKey} in settings`;
109-
}
110-
break;
111-
}
112-
}
98+
// Build disabled mapping for options.
99+
const combinedOptions = optionsList.reduce(
100+
(acc: Record<string, { disabled: boolean; info?: string }>, opt) => {
101+
const lower = opt.toLowerCase();
102+
let optDisabled = false;
103+
let infoMessage = '';
104+
if (isDisabledFilter) {
105+
for (const envKey in disabledEnvMap) {
106+
if (lower.startsWith(envKey)) {
107+
const configKey = disabledEnvMap[envKey];
108+
const configVal = configValues[configKey];
109+
optDisabled = !configVal || configVal === '';
110+
if (optDisabled) {
111+
infoMessage = `Please set ${configKey} in settings`;
113112
}
114-
acc[opt] = { disabled: optDisabled, info: optDisabled ? infoMessage : '' };
115-
return acc;
116-
},
117-
{}
118-
);
113+
break;
114+
}
115+
}
116+
}
117+
acc[opt] = {
118+
disabled: optDisabled,
119+
info: optDisabled ? infoMessage : '',
120+
};
121+
return acc;
122+
},
123+
{},
124+
);
119125

120-
return (
121-
<>
122-
<Autocomplete
123-
multiple={_multiple}
124-
id={id}
125-
placeholder={schema.title || ''}
126-
options={optionsList}
127-
getOptionLabel={(option) =>
128-
option +
129-
(combinedOptions[option]?.disabled ? ' - ' + combinedOptions[option].info : '')
130-
}
131-
getOptionDisabled={(option) => combinedOptions[option]?.disabled ?? false}
132-
value={displayValue}
133-
onChange={(event, newValue) => {
134-
const storedValue = _multiple
135-
? newValue.map((label) => labelToCustomValue[label] || label)
136-
: (labelToCustomValue[newValue] || newValue);
137-
onChange(storedValue);
138-
}}
139-
disabled={disabled || readonly}
140-
autoFocus={autofocus}
141-
/>
142-
{/* Hidden input to capture the stored value on form submission */}
143-
<input
144-
type="hidden"
145-
name={id}
146-
value={
147-
_multiple
148-
? Array.isArray(currentValue)
149-
? currentValue.join(',')
150-
: currentValue
151-
: currentValue
152-
}
153-
/>
154-
</>
155-
);
126+
return (
127+
<>
128+
<Autocomplete
129+
multiple={_multiple}
130+
id={id}
131+
placeholder={schema.title || ''}
132+
options={optionsList}
133+
getOptionLabel={(option) =>
134+
option +
135+
(combinedOptions[option]?.disabled
136+
? ' - ' + combinedOptions[option].info
137+
: '')
138+
}
139+
getOptionDisabled={(option) =>
140+
combinedOptions[option]?.disabled ?? false
141+
}
142+
value={displayValue}
143+
onChange={(event, newValue) => {
144+
const storedValue = _multiple
145+
? newValue.map((label) => labelToCustomValue[label] || label)
146+
: labelToCustomValue[newValue] || newValue;
147+
onChange(storedValue);
148+
}}
149+
disabled={disabled || readonly}
150+
autoFocus={autofocus}
151+
/>
152+
{/* Hidden input to capture the stored value on form submission */}
153+
<input
154+
type="hidden"
155+
name={id}
156+
value={
157+
_multiple
158+
? Array.isArray(currentValue)
159+
? currentValue.join(',')
160+
: currentValue
161+
: currentValue
162+
}
163+
/>
164+
</>
165+
);
156166
}
157167

158168
export default ModelProviderWidget;

0 commit comments

Comments
 (0)