|
| 1 | +from collections import defaultdict |
| 2 | + |
| 3 | +import pickle |
| 4 | + |
| 5 | +import numpy as np |
| 6 | + |
| 7 | +from torch.utils.data import Dataset |
| 8 | + |
| 9 | + |
class CancerPatients(Dataset):
    """Map-style dataset of per-patient cancer-volume sequences.

    The constructor standardises the continuous inputs with the supplied
    mean/std, encodes the (chemo, radio) treatment pair of every timestep as
    a 4-way one-hot vector, and stores the resulting arrays under
    ``self._data["default"]``. When ``factuals`` is given, the same
    processing is applied to it and stored under ``self._data["factuals"]``.

    Indexing returns a list of six arrays for one patient, in the order given
    by ``_OUTPUT_KEYS``; when factual data is enabled it returns a
    ``(sample, factual_sample)`` pair of such lists.
    """

    # Order of the arrays returned by ``__getitem__``.
    _OUTPUT_KEYS = (
        "current_covariates",
        "previous_treatments",
        "current_treatments",
        "outputs",
        "active_entries",
        "init_state",
    )

    def __init__(
        self,
        data,
        scaling_data,
        chemo_coeff,
        radio_coeff,
        num_time_steps,
        window_size,
        factuals=None,
        transform=None,
    ):
        """Process ``data`` (and optionally ``factuals``) into model arrays.

        Args:
            data: Mapping providing ``"cancer_volume"``, ``"patient_types"``,
                ``"chemo_application"``, ``"radio_application"`` and
                ``"sequence_lengths"``.
            scaling_data: ``(mean, std)`` pair indexable by
                ``"cancer_volume"`` and ``"patient_types"``. Neither object
                is modified.
            chemo_coeff: Stored on the instance; not used by this class.
            radio_coeff: Stored on the instance; not used by this class.
            num_time_steps: Stored on the instance; not used by this class.
            window_size: Stored on the instance; not used by this class.
            factuals: Optional mapping with the same keys as ``data``. When
                given, indexing returns a ``(sample, factual_sample)`` pair.
            transform: Stored as ``self.transform``; not applied by this
                class.
        """
        super().__init__()
        self._data = defaultdict(dict)

        self._process_input_data(input_data=data, scaling_data=scaling_data)

        self._return_factual_data = False

        if factuals is not None:
            self._process_input_data(
                input_data=factuals,
                scaling_data=scaling_data,
                data_key="factuals",
            )
            self._return_factual_data = True

        self._chemo_coeff = chemo_coeff
        self._radio_coeff = radio_coeff
        self._num_time_steps = num_time_steps
        self._window_size = window_size

        self.transform = transform

        self._data_keys = [f"_{_key}" for _key in data.keys()]

    def return_factual_data(self, flag=True):
        """Choose whether ``__getitem__`` also returns the factual sample."""
        self._return_factual_data = flag

    @staticmethod
    def _one_hot_treatments(chemo, radio):
        """Encode (chemo, radio) pairs as one-hot vectors along a new last axis.

        Classes are ``[none, chemo only, radio only, both]``. A position
        whose values are not exactly 0 or 1 stays all-zero.
        """
        chemo = np.asarray(chemo)
        radio = np.asarray(radio)
        encoded = np.zeros(shape=chemo.shape + (4,))
        encoded[..., 0] = (chemo == 0) & (radio == 0)
        encoded[..., 1] = (chemo == 1) & (radio == 0)
        encoded[..., 2] = (chemo == 0) & (radio == 1)
        encoded[..., 3] = (chemo == 1) & (radio == 1)
        return encoded

    def _process_input_data(self, input_data, scaling_data, data_key="default"):
        """Scale and encode ``input_data``, storing it in ``self._data[data_key]``.

        Args:
            input_data: Mapping with the keys described in ``__init__``.
            scaling_data: ``(mean, std)`` pair; copies are taken so the
                caller's objects are left untouched.
            data_key: Key of ``self._data`` under which results are stored.
        """
        offset = 1
        horizon = 1

        mean, std = scaling_data
        # Work on copies: the entries added below must not leak into the
        # caller's scaling statistics.
        mean = mean.copy()
        std = std.copy()

        # Treatment indicators are passed through unscaled (mean 0, std 1).
        mean["chemo_application"] = 0
        mean["radio_application"] = 0
        std["chemo_application"] = 1
        std["radio_application"] = 1

        input_columns = [
            "cancer_volume",
            "patient_types",
            "chemo_application",
            "radio_application",
        ]
        input_means = mean[input_columns].values.flatten()
        input_stds = std[input_columns].values.flatten()

        # Continuous values
        cancer_volume = (input_data["cancer_volume"] - mean["cancer_volume"]) / std[
            "cancer_volume"
        ]
        patient_types = (input_data["patient_types"] - mean["patient_types"]) / std[
            "patient_types"
        ]

        # Repeat the per-patient type along a new time axis so it lines up
        # with cancer_volume.
        patient_types = np.repeat(
            np.asarray(patient_types)[:, np.newaxis],
            cancer_volume.shape[1],
            axis=1,
        )

        # Binary application
        chemo_application = input_data["chemo_application"]
        radio_application = input_data["radio_application"]
        sequence_lengths = input_data["sequence_lengths"]

        # Convert treatments to one-hot encoding (last timestep dropped).
        one_hot_treatments = self._one_hot_treatments(
            np.asarray(chemo_application)[:, :-offset],
            np.asarray(radio_application)[:, :-offset],
        )

        one_hot_previous_treatments = one_hot_treatments[:, :-1, :]

        current_covariates = np.concatenate(
            [
                cancer_volume[:, :-offset, np.newaxis],
                patient_types[:, :-offset, np.newaxis],
            ],
            axis=-1,
        )
        # Targets are the scaled cancer volume shifted by ``horizon`` steps.
        outputs = cancer_volume[:, horizon:, np.newaxis]

        # Only scalars are needed for the output statistics.
        output_means = mean[["cancer_volume"]].values.flatten()[0]
        output_stds = std[["cancer_volume"]].values.flatten()[0]

        # Add active entries: 1 for the first ``sequence_length`` steps of
        # each patient, 0 afterwards.
        active_entries = np.zeros(outputs.shape)
        for row, length in enumerate(sequence_lengths):
            active_entries[row, : int(length), :] = 1

        store = self._data[data_key]
        store["current_covariates"] = current_covariates
        store["previous_treatments"] = one_hot_previous_treatments
        store["current_treatments"] = one_hot_treatments
        store["outputs"] = outputs
        store["active_entries"] = active_entries

        store["unscaled_outputs"] = (
            outputs * std["cancer_volume"] + mean["cancer_volume"]
        )
        store["input_means"] = input_means
        # Key spelling "inputs_stds" is kept as-is for existing readers.
        store["inputs_stds"] = input_stds
        store["output_means"] = output_means
        store["output_stds"] = output_stds

        # this is placeholder for some RNN decoder input
        store["init_state"] = np.zeros_like(outputs)

    def __len__(self):
        """Return the number of patients in the default data."""
        return self._data["default"]["current_covariates"].shape[0]

    def __getitem__(self, idx):
        """Return the arrays for patient ``idx``.

        Returns:
            A list with one array per entry of ``_OUTPUT_KEYS``; or, when
            factual data is enabled, a ``(sample, factual_sample)`` tuple of
            two such lists.
        """
        sample = [self._data["default"][key][idx] for key in self._OUTPUT_KEYS]

        if not self._return_factual_data:
            return sample

        factual_sample = [
            self._data["factuals"][key][idx] for key in self._OUTPUT_KEYS
        ]
        return sample, factual_sample
| 177 | + |
| 178 | + |
def sum_all(inp):
    """Print a list holding ``item.sum()`` for every item of ``inp``."""
    totals = [item.sum() for item in inp]
    print(totals)
| 184 | + |
| 185 | + |
# Manual smoke test: builds a training and a test dataset, indexes the first
# patient of each, and prints the per-array sums via ``sum_all``.
if __name__ == "__main__":
    # NOTE(review): ``CancerPatients.__init__`` declares no ``filename``
    # parameter and requires ``data``, ``scaling_data``, ``chemo_coeff``,
    # ``radio_coeff``, ``num_time_steps`` and ``window_size``, so these
    # constructor calls look stale — confirm and update.
    data = CancerPatients(filename="/data/chrisl/CRN-data/training.p")
    sample = data[0]
    print("Training sample:")
    sum_all(sample)

    test_data = CancerPatients(filename="/data/chrisl/CRN-data/test.p")
    tample = test_data[0]
    print("Test sample:")
    print(len(tample))
    # Indexing ``tample[0]`` / ``tample[1]`` presumably expects the
    # ``(sample, factual_sample)`` pair returned when factual data is
    # enabled — TODO confirm the test set is built with factuals.
    sum_all(tample[0])
    sum_all(tample[1])

    # With the flag off, indexing returns only the default sample list.
    # NOTE(review): the label printed below says "factual" although factual
    # output has just been disabled — confirm the intended wording.
    test_data.return_factual_data(False)
    ntample = test_data[0]
    print("Test factual sample:")
    print(len(ntample))
    sum_all(ntample)

    # import pdb

    # pdb.set_trace()
0 commit comments