Skip to content

Commit 124feaa

Browse files
author
Chris Lucas
committed
first commit
1 parent 55d01f0 commit 124feaa

File tree

8 files changed

+1682
-0
lines changed

8 files changed

+1682
-0
lines changed

.gitignore

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
__pycache__
2+
.DS_Store
3+
*.p
4+
*.png
5+
*.pdf
6+
7+
results
8+
simulation_cache

dataset/__init__.py

Whitespace-only changes.

dataset/cancer_loader.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
from collections import defaultdict
2+
3+
import pickle
4+
5+
import numpy as np
6+
7+
from torch.utils.data import Dataset
8+
9+
10+
class CancerPatients(Dataset):
11+
def __init__(
12+
self,
13+
data,
14+
scaling_data,
15+
chemo_coeff,
16+
radio_coeff,
17+
num_time_steps,
18+
window_size,
19+
factuals=None,
20+
transform=None,
21+
):
22+
self._data = defaultdict(dict)
23+
24+
self._process_input_data(
25+
input_data=data, scaling_data=scaling_data,
26+
)
27+
28+
self._return_factual_data = False
29+
30+
if factuals is not None:
31+
self._process_input_data(
32+
input_data=factuals, scaling_data=scaling_data, data_key="factuals",
33+
)
34+
self._return_factual_data = True
35+
36+
self._chemo_coeff = chemo_coeff
37+
self._radio_coeff = radio_coeff
38+
self._num_time_steps = num_time_steps
39+
self._window_size = window_size
40+
41+
self.transform = transform
42+
43+
self._data_keys = [f"_{_key}" for _key in data.keys()]
44+
45+
def return_factual_data(self, flag=True):
46+
self._return_factual_data = flag
47+
48+
def _process_input_data(self, input_data, scaling_data, data_key="default"):
49+
offset = 1
50+
horizon = 1
51+
52+
mean, std = scaling_data
53+
54+
mean["chemo_application"] = 0
55+
mean["radio_application"] = 0
56+
std["chemo_application"] = 1
57+
std["radio_application"] = 1
58+
59+
input_means = mean[
60+
["cancer_volume", "patient_types", "chemo_application", "radio_application"]
61+
].values.flatten()
62+
input_stds = std[
63+
["cancer_volume", "patient_types", "chemo_application", "radio_application"]
64+
].values.flatten()
65+
66+
# Continuous values
67+
cancer_volume = (input_data["cancer_volume"] - mean["cancer_volume"]) / std[
68+
"cancer_volume"
69+
]
70+
patient_types = (input_data["patient_types"] - mean["patient_types"]) / std[
71+
"patient_types"
72+
]
73+
74+
patient_types = np.stack(
75+
[patient_types for t in range(cancer_volume.shape[1])], axis=1
76+
)
77+
78+
# Binary application
79+
chemo_application = input_data["chemo_application"]
80+
radio_application = input_data["radio_application"]
81+
sequence_lengths = input_data["sequence_lengths"]
82+
83+
# Convert treatments to one-hot encoding
84+
treatments = np.concatenate(
85+
[
86+
chemo_application[:, :-offset, np.newaxis],
87+
radio_application[:, :-offset, np.newaxis],
88+
],
89+
axis=-1,
90+
)
91+
92+
one_hot_treatments = np.zeros(
93+
shape=(treatments.shape[0], treatments.shape[1], 4)
94+
)
95+
for patient_id in range(treatments.shape[0]):
96+
for timestep in range(treatments.shape[1]):
97+
if (
98+
treatments[patient_id][timestep][0] == 0
99+
and treatments[patient_id][timestep][1] == 0
100+
):
101+
one_hot_treatments[patient_id][timestep] = [1, 0, 0, 0]
102+
elif (
103+
treatments[patient_id][timestep][0] == 1
104+
and treatments[patient_id][timestep][1] == 0
105+
):
106+
one_hot_treatments[patient_id][timestep] = [0, 1, 0, 0]
107+
elif (
108+
treatments[patient_id][timestep][0] == 0
109+
and treatments[patient_id][timestep][1] == 1
110+
):
111+
one_hot_treatments[patient_id][timestep] = [0, 0, 1, 0]
112+
elif (
113+
treatments[patient_id][timestep][0] == 1
114+
and treatments[patient_id][timestep][1] == 1
115+
):
116+
one_hot_treatments[patient_id][timestep] = [0, 0, 0, 1]
117+
118+
one_hot_previous_treatments = one_hot_treatments[:, :-1, :]
119+
120+
current_covariates = np.concatenate(
121+
[
122+
cancer_volume[:, :-offset, np.newaxis],
123+
patient_types[:, :-offset, np.newaxis],
124+
],
125+
axis=-1,
126+
)
127+
outputs = cancer_volume[:, horizon:, np.newaxis]
128+
129+
output_means = mean[["cancer_volume"]].values.flatten()[
130+
0
131+
] # because we only need scalars here
132+
output_stds = std[["cancer_volume"]].values.flatten()[0]
133+
134+
# Add active entires
135+
active_entries = np.zeros(outputs.shape)
136+
137+
for i in range(sequence_lengths.shape[0]):
138+
sequence_length = int(sequence_lengths[i])
139+
active_entries[i, :sequence_length, :] = 1
140+
141+
self._data[data_key]["current_covariates"] = current_covariates
142+
self._data[data_key]["previous_treatments"] = one_hot_previous_treatments
143+
self._data[data_key]["current_treatments"] = one_hot_treatments
144+
self._data[data_key]["outputs"] = outputs
145+
self._data[data_key]["active_entries"] = active_entries
146+
147+
self._data[data_key]["unscaled_outputs"] = (
148+
outputs * std["cancer_volume"] + mean["cancer_volume"]
149+
)
150+
self._data[data_key]["input_means"] = input_means
151+
self._data[data_key]["inputs_stds"] = input_stds
152+
self._data[data_key]["output_means"] = output_means
153+
self._data[data_key]["output_stds"] = output_stds
154+
155+
# this is placeholder for some RNN decoder input
156+
self._data[data_key]["init_state"] = np.zeros_like(outputs)
157+
158+
def __len__(self):
159+
return self._data["default"]["current_covariates"].shape[0]
160+
161+
def __getitem__(self, idx):
162+
output_keys = [
163+
"current_covariates",
164+
"previous_treatments",
165+
"current_treatments",
166+
"outputs",
167+
"active_entries",
168+
"init_state",
169+
]
170+
sample = [self._data["default"][key][idx] for key in output_keys]
171+
172+
if not self._return_factual_data:
173+
return sample
174+
else:
175+
factual_sample = [self._data["factuals"][key][idx] for key in output_keys]
176+
return sample, factual_sample
177+
178+
179+
def sum_all(inp):
180+
summers = []
181+
for thing in inp:
182+
summers.append(thing.sum())
183+
print(summers)
184+
185+
186+
if __name__ == "__main__":
187+
data = CancerPatients(filename="/data/chrisl/CRN-data/training.p")
188+
sample = data[0]
189+
print("Training sample:")
190+
sum_all(sample)
191+
192+
test_data = CancerPatients(filename="/data/chrisl/CRN-data/test.p")
193+
tample = test_data[0]
194+
print("Test sample:")
195+
print(len(tample))
196+
sum_all(tample[0])
197+
sum_all(tample[1])
198+
199+
test_data.return_factual_data(False)
200+
ntample = test_data[0]
201+
print("Test factual sample:")
202+
print(len(ntample))
203+
sum_all(ntample)
204+
205+
# import pdb
206+
207+
# pdb.set_trace()

0 commit comments

Comments
 (0)