Skip to content

Commit d88868a

Browse files
authored
Merge pull request transformerlab#711 from transformerlab/add/checkpointing-ui
Add/checkpointing UI
2 parents db680d5 + 9925856 commit d88868a

File tree

3 files changed

+134
-2
lines changed

3 files changed

+134
-2
lines changed

src/renderer/components/Experiment/Train/TrainLoRA.tsx

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import {
3434
StopCircleIcon,
3535
Trash2Icon,
3636
UploadIcon,
37+
WaypointsIcon,
3738
} from 'lucide-react';
3839

3940
import dayjs from 'dayjs';
@@ -51,6 +52,7 @@ import TensorboardModal from './TensorboardModal';
5152
import ViewOutputModal from './ViewOutputModal';
5253
import ViewEvalImagesModal from './ViewEvalImagesModal';
5354
import { useExperimentInfo } from 'renderer/lib/ExperimentInfoContext';
55+
import ViewCheckpointsModal from './ViewCheckpointsModal';
5456
dayjs.extend(relativeTime);
5557
var duration = require('dayjs/plugin/duration');
5658
dayjs.extend(duration);
@@ -106,6 +108,7 @@ export default function TrainLoRA({}) {
106108
const [viewEvalImagesFromJob, setViewEvalImagesFromJob] = useState(-1);
107109
const [templateID, setTemplateID] = useState('-1');
108110
const [currentPlugin, setCurrentPlugin] = useState('');
111+
const [viewCheckpointsFromJob, setViewCheckpointsFromJob] = useState(-1);
109112

110113
const { data, error, isLoading, mutate } = useSWR(
111114
chatAPI.Endpoints.Tasks.ListByTypeInExperiment('TRAIN', experimentInfo?.id),
@@ -190,12 +193,16 @@ export default function TrainLoRA({}) {
190193
sweeps={viewOutputFromSweepJob}
191194
setsweepJob={setViewOutputFromSweepJob}
192195
/>
193-
194196
<ViewEvalImagesModal
195197
open={viewEvalImagesFromJob !== -1}
196198
onClose={() => setViewEvalImagesFromJob(-1)}
197199
jobId={viewEvalImagesFromJob}
198200
/>
201+
<ViewCheckpointsModal
202+
open={viewCheckpointsFromJob !== -1}
203+
onClose={() => setViewCheckpointsFromJob(-1)}
204+
jobId={viewCheckpointsFromJob}
205+
/>
199206
<Sheet
200207
sx={{
201208
display: 'flex',
@@ -523,6 +530,18 @@ export default function TrainLoRA({}) {
523530
Sweep Output
524531
</Button>
525532
)}
533+
{job?.job_data?.checkpoints && (
534+
<Button
535+
size="sm"
536+
variant="plain"
537+
onClick={() => {
538+
setViewCheckpointsFromJob(job?.id);
539+
}}
540+
startDecorator={<WaypointsIcon />}
541+
>
542+
Checkpoints
543+
</Button>
544+
)}
526545
<IconButton variant="plain">
527546
<Trash2Icon
528547
onClick={async () => {
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import {
2+
Modal,
3+
ModalDialog,
4+
Typography,
5+
ModalClose,
6+
Table,
7+
Button,
8+
Box,
9+
} from '@mui/joy';
10+
import { PlayIcon } from 'lucide-react';
11+
import { useAPI } from 'renderer/lib/transformerlab-api-sdk';
12+
import { formatBytes } from 'renderer/lib/utils';
13+
14+
export default function ViewCheckpointsModal({ open, onClose, jobId }) {
15+
const { data, isLoading: checkpointsLoading } = useAPI(
16+
'jobs',
17+
['getCheckpoints'],
18+
{ jobId },
19+
);
20+
21+
const handleRestartFromCheckpoint = (checkpoint) => {
22+
// TODO: Implement restart functionality
23+
console.log('Restarting from checkpoint:', checkpoint);
24+
};
25+
26+
let noCheckpoints = false;
27+
28+
if (!checkpointsLoading && data?.checkpoints?.length === 0) {
29+
noCheckpoints = true;
30+
}
31+
32+
return (
33+
<Modal open={open} onClose={() => onClose()}>
34+
<ModalDialog sx={{ minWidth: '80%' }}>
35+
<ModalClose />
36+
37+
{noCheckpoints ? (
38+
<Typography level="body-md" sx={{ textAlign: 'center', py: 4 }}>
39+
No checkpoints were saved in this job.
40+
</Typography>
41+
) : (
42+
<>
43+
<Typography level="h4" component="h2">
44+
Checkpoints for Job {jobId}
45+
</Typography>
46+
47+
{!checkpointsLoading && data && (
48+
<Box sx={{ mb: 2 }}>
49+
<Typography level="body-md">
50+
<strong>Model:</strong> {data.model_name}
51+
</Typography>
52+
<Typography level="body-md">
53+
<strong>Adaptor:</strong> {data.adaptor_name}
54+
</Typography>
55+
</Box>
56+
)}
57+
58+
{checkpointsLoading ? (
59+
<Typography level="body-md">Loading checkpoints...</Typography>
60+
) : (
61+
<Box sx={{ maxHeight: 400, overflow: 'auto' }}>
62+
<Table>
63+
<thead>
64+
<tr>
65+
<th width="50px">#</th>
66+
<th>Checkpoint</th>
67+
<th>Date</th>
68+
<th width="100px">Size</th>
69+
<th style={{ textAlign: 'right' }}>&nbsp;</th>
70+
</tr>
71+
</thead>
72+
<tbody>
73+
{data?.checkpoints?.map((checkpoint, index) => (
74+
<tr key={index}>
75+
<td>
76+
<Typography level="body-sm">
77+
{data?.checkpoints?.length - index}.
78+
</Typography>
79+
</td>
80+
<td>
81+
<Typography level="title-sm">
82+
{checkpoint.filename}
83+
</Typography>
84+
</td>
85+
<td>{new Date(checkpoint.date).toLocaleString()}</td>
86+
<td>{formatBytes(checkpoint.size)}</td>
87+
<td style={{ textAlign: 'right' }}>
88+
{/* <Button
89+
size="sm"
90+
variant="outlined"
91+
onClick={() =>
92+
handleRestartFromCheckpoint(checkpoint.filename)
93+
}
94+
startDecorator={<PlayIcon />}
95+
>
96+
Restart training from here
97+
</Button> */}
98+
</td>
99+
</tr>
100+
))}
101+
</tbody>
102+
</Table>
103+
</Box>
104+
)}
105+
</>
106+
)}
107+
</ModalDialog>
108+
</Modal>
109+
);
110+
}

src/renderer/lib/api-client/allEndpoints.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@
8787
"stop": {
8888
"method": "GET",
8989
"path": "jobs/{jobId}/stop"
90+
},
91+
"getCheckpoints": {
92+
"method": "GET",
93+
"path": "jobs/{jobId}/checkpoints"
9094
}
9195
},
9296
"datasets": {
@@ -234,6 +238,5 @@
234238
"method": "GET",
235239
"path": "train/job/{job_id}/sweep_results"
236240
}
237-
238241
}
239242
}

0 commit comments

Comments
 (0)