Skip to content

Commit 7d2bcf1

Browse files
authored
Merge pull request transformerlab#770 from transformerlab/fix/model-groups-search
Model groups search bar now looks at more things than just model groups name
2 parents d2ae760 + fe10840 commit 7d2bcf1

File tree

1 file changed

+115
-73
lines changed

1 file changed

+115
-73
lines changed

src/renderer/components/ModelZoo/ModelGroups.tsx

Lines changed: 115 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ export default function ModelGroups({ experimentInfo }) {
178178
const [currentlyDownloading, setCurrentlyDownloading] = useState(null);
179179
const [canceling, setCanceling] = useState(false);
180180
const [groupSearchText, setGroupSearchText] = useState('');
181+
const [showFilters, setShowFilters] = useState(false);
181182

182183
const {
183184
data: groupData,
@@ -223,7 +224,7 @@ export default function ModelGroups({ experimentInfo }) {
223224

224225
useEffect(() => {
225226
if (selectedGroup) {
226-
setFilters({ archived: false, license: 'All', architecture: 'All' });
227+
setFilters({ archived: false, architecture: 'All' });
227228
}
228229
}, [selectedGroup]);
229230

@@ -240,17 +241,17 @@ export default function ModelGroups({ experimentInfo }) {
240241
}
241242
}, [groupData, selectedGroup]);
242243

243-
const getLicenseOptions = (models) => {
244-
const lowercaseSet = new Set();
245-
models?.forEach((m) => {
244+
const getLicenseOptions = (models: any[]): string[] => {
245+
const lowercaseSet = new Set<string>();
246+
models?.forEach((m: any) => {
246247
if (m.license) lowercaseSet.add(m.license.toLowerCase());
247248
});
248249
return Array.from(lowercaseSet).sort();
249250
};
250251

251-
const getArchitectureOptions = (models) => {
252-
const lowercaseSet = new Set();
253-
models?.forEach((m) => {
252+
const getArchitectureOptions = (models: any[]): string[] => {
253+
const lowercaseSet = new Set<string>();
254+
models?.forEach((m: any) => {
254255
if (m.architecture) lowercaseSet.add(m.architecture.toLowerCase());
255256
});
256257
return Array.from(lowercaseSet).sort();
@@ -259,15 +260,15 @@ export default function ModelGroups({ experimentInfo }) {
259260
const licenseOptions = selectedGroup
260261
? getLicenseOptions(selectedGroup.models)
261262
: [];
262-
const archOptions = selectedGroup
263+
const archOptions: string[] = selectedGroup
263264
? getArchitectureOptions(selectedGroup.models)
264265
: [];
265266

266267
if (isLoading) return <ModelGroupsSkeleton />;
267268
if (error) return <Typography>Error loading model groups.</Typography>;
268269
if (!groupData || !selectedGroup) return null;
269270

270-
const handleSortClick = (column) => {
271+
const handleSortClick = (column: string) => {
271272
const isAsc = orderBy === column && order === 'asc';
272273
setOrder(isAsc ? 'desc' : 'asc');
273274
setOrderBy(column);
@@ -357,7 +358,7 @@ export default function ModelGroups({ experimentInfo }) {
357358
}}
358359
>
359360
<Input
360-
placeholder="Search groups"
361+
placeholder="Search"
361362
value={groupSearchText}
362363
onChange={(e) => setGroupSearchText(e.target.value)}
363364
startDecorator={<SearchIcon />}
@@ -374,11 +375,38 @@ export default function ModelGroups({ experimentInfo }) {
374375
}}
375376
>
376377
{[...groupData]
377-
.filter((group) =>
378-
group.name
379-
.toLowerCase()
380-
.includes(groupSearchText.toLowerCase()),
381-
)
378+
.filter((group) => {
379+
if (!groupSearchText) return true;
380+
381+
const searchLower = groupSearchText.toLowerCase();
382+
383+
// Search in group properties
384+
const groupFields = [
385+
group.name,
386+
group.description,
387+
...(group.tags || []),
388+
];
389+
390+
// Search in model properties within the group
391+
const modelFields =
392+
group.models?.flatMap((model: any) => [
393+
model.name,
394+
model.architecture,
395+
model.license,
396+
model.description,
397+
model.huggingface_repo,
398+
model.id,
399+
...(model.tags || []),
400+
]) || [];
401+
402+
const allSearchableFields = [...groupFields, ...modelFields];
403+
404+
return allSearchableFields.some(
405+
(field) =>
406+
field &&
407+
field.toString().toLowerCase().includes(searchLower),
408+
);
409+
})
382410
.sort((a, b) => a.name.localeCompare(b.name))
383411
.map((group) => {
384412
const isSelected = selectedGroup?.name === group.name;
@@ -486,72 +514,86 @@ export default function ModelGroups({ experimentInfo }) {
486514
sx={{
487515
display: 'flex',
488516
alignItems: 'center',
489-
gap: 1,
517+
justifyContent: 'space-between',
490518
flexWrap: 'wrap',
491519
mb: 1,
492520
}}
493521
>
494-
<Typography level="h4">
495-
{selectedGroup.name.charAt(0).toUpperCase() +
496-
selectedGroup.name.slice(1)}
497-
</Typography>
498-
{selectedGroup.tags?.map((tag) => (
499-
<Chip
500-
key={tag}
501-
size="sm"
502-
variant="outlined"
503-
sx={{
504-
fontSize: '0.7rem',
505-
variant: 'soft',
506-
color: 'info',
507-
}}
508-
>
509-
{tag}
510-
</Chip>
511-
))}
522+
<Box
523+
sx={{
524+
display: 'flex',
525+
alignItems: 'center',
526+
gap: 1,
527+
flexWrap: 'wrap',
528+
}}
529+
>
530+
<Typography level="h4">
531+
{selectedGroup.name.charAt(0).toUpperCase() +
532+
selectedGroup.name.slice(1)}
533+
</Typography>
534+
{selectedGroup.tags?.map((tag: string) => (
535+
<Chip
536+
key={tag}
537+
size="sm"
538+
variant="outlined"
539+
sx={{
540+
fontSize: '0.7rem',
541+
variant: 'soft',
542+
color: 'info',
543+
}}
544+
>
545+
{tag}
546+
</Chip>
547+
))}
548+
</Box>
549+
<Button
550+
size="sm"
551+
variant="outlined"
552+
onClick={() => setShowFilters(!showFilters)}
553+
>
554+
Filters {showFilters ? '▲' : '▼'}
555+
</Button>
512556
</Box>
513557
{/* <Typography level="body-md" sx={{ mb: 2 }}>
514558
{selectedGroup.description}
515559
</Typography> */}
516-
<Box sx={{ display: 'flex', flexWrap: 'wrap', gap: 1.5 }}>
517-
<FormControl sx={{ flex: 1 }} size="sm">
518-
<FormLabel>&nbsp;</FormLabel>
519-
<Input
520-
placeholder="Search"
521-
value={searchText}
522-
onChange={(e) => setSearchText(e.target.value)}
523-
startDecorator={<SearchIcon />}
524-
/>
525-
</FormControl>
526-
<FormControl size="sm">
527-
<FormLabel>Status</FormLabel>
528-
<Select
529-
value={filters?.archived}
530-
onChange={(e, newValue) =>
531-
setFilters({ ...filters, archived: newValue })
532-
}
533-
>
534-
<Option value={false}>Hide Archived</Option>
535-
<Option value="All">Show Archived</Option>
536-
</Select>
537-
</FormControl>
538-
<FormControl size="sm">
539-
<FormLabel>Architecture</FormLabel>
540-
<Select
541-
value={filters?.architecture}
542-
onChange={(e, newValue) =>
543-
setFilters({ ...filters, architecture: newValue })
544-
}
545-
>
546-
<Option value="All">All</Option>
547-
{archOptions.map((type) => (
548-
<Option key={type} value={type}>
549-
{type}
550-
</Option>
551-
))}
552-
</Select>
553-
</FormControl>
554-
</Box>
560+
{showFilters && (
561+
<Box
562+
sx={{ display: 'flex', flexWrap: 'wrap', gap: 1.5, mt: 1 }}
563+
>
564+
<FormControl size="sm">
565+
<FormLabel>Status</FormLabel>
566+
<Select
567+
value={filters?.archived}
568+
onChange={(e, newValue) =>
569+
setFilters({ ...filters, archived: newValue ?? false })
570+
}
571+
>
572+
<Option value={false}>Hide Archived</Option>
573+
<Option value="All">Show Archived</Option>
574+
</Select>
575+
</FormControl>
576+
<FormControl size="sm">
577+
<FormLabel>Architecture</FormLabel>
578+
<Select
579+
value={filters?.architecture}
580+
onChange={(e, newValue) =>
581+
setFilters({
582+
...filters,
583+
architecture: newValue ?? 'All',
584+
})
585+
}
586+
>
587+
<Option value="All">All</Option>
588+
{archOptions.map((type: string) => (
589+
<Option key={type} value={type}>
590+
{type}
591+
</Option>
592+
))}
593+
</Select>
594+
</FormControl>
595+
</Box>
596+
)}
555597
</Sheet>
556598

557599
<Box

0 commit comments

Comments
 (0)