"""
Test:
Local test of train.py:
python train.py --train "../../test_data/train/" --validation "../../test_data/val/" --model-dir "../../test_data/"

VS Code launch.json:
{
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python: Current File",
            "type": "python",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal",
            "cwd": "${fileDirname}",
            "args": [
                "--train",
                "../../test_data/train/",
                "--validation",
                "../../test_data/val/",
                "--model-dir",
                "../../test_data/"
            ]
        }
    ]
}
"""
from sklearn.datasets import make_classification
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
import os
from utils import print_files_in_path


class MyDataset(Dataset):
    def __init__(self, n_samples, n_features, n_classes):
        self.n_samples = n_samples
        self.X, self.Y = make_classification(
            n_samples=n_samples,
            n_features=n_features,
            n_redundant=0,
            n_informative=2,
            n_clusters_per_class=1,
            n_classes=n_classes,
        )

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        # Model expects float32 inputs and long (int64) class targets
        return torch.tensor(self.X[idx, :], dtype=torch.float32), torch.tensor(self.Y[idx], dtype=torch.long)

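# Illustrative check of the synthetic dataset (sizes are arbitrary):
#   ds = MyDataset(n_samples=10, n_features=5, n_classes=3)
#   x, y = ds[0]   # x: float32 tensor of shape (5,), y: long scalar tensor
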
class Net(nn.Module):
    """Minimal classifier: a single linear layer over the input features,
    with the output size fixed at 3 classes to match MyDataset above."""

    def __init__(self, input_features):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_features, 3)

    def forward(self, x):
        x = self.fc1(x)
        output = F.log_softmax(x, dim=1)
        return output

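# Quick shape check (illustrative): a batch of 4 samples with 5 features
# yields log-probabilities over the 3 classes.
#   net = Net(input_features=5)
#   out = net(torch.randn(4, 5))   # out.shape == torch.Size([4, 3])
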
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)
        )
    )


# SageMaker inference: load the weights written by save_model() below.
def model_fn(model_dir):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE: input_features must match the value used in main() during training.
    model = Net(input_features=5)
    # Load the plain state dict first; wrapping in DataParallel beforehand
    # would fail, since the saved keys carry no "module." prefix.
    with open(os.path.join(model_dir, "model.pth"), "rb") as f:
        model.load_state_dict(torch.load(f, map_location=device))
    if torch.cuda.device_count() > 1:
        print("Gpu count: {}".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)
    return model.to(device)


# SageMaker: persist the trained weights to the model directory.
def save_model(model, model_dir):
    path = os.path.join(model_dir, "model.pth")
    torch.save(model.state_dict(), path)

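# Illustrative round trip (the /tmp path is a placeholder): save_model writes
# the plain state dict that model_fn loads back before any DataParallel wrap.
#   save_model(Net(input_features=5), "/tmp")
#   model = model_fn("/tmp")
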
def main(args):
    """
    SM_CHANNEL paths do not contain a trailing slash:
    SM_CHANNEL_TRAIN=/opt/ml/input/data/train
    SM_CHANNEL_VALIDATION=/opt/ml/input/data/validation

    Training job name:
    script-mode-container-xgb-2020-08-10-13-29-15-756
    """
    train_channel, validation_channel, model_dir = args.train, args.validation, args.model_dir

    print("\nList of files in train channel: ")
    print_files_in_path(train_channel)

    print("\nList of files in validation channel: ")
    print_files_in_path(validation_channel)

    use_cuda = torch.cuda.is_available()
    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")
    print("Device:", device)
    kwargs = {"num_workers": 8, "pin_memory": True} if use_cuda else {}

    input_features = 5
    n_samples = 1000
    dataset = MyDataset(n_samples, input_features, 3)
    train_len = int(n_samples * 0.7)
    test_len = n_samples - train_len
    train_set, val_set = torch.utils.data.random_split(dataset, [train_len, test_len])
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=True, **kwargs)

    model = Net(input_features).to(device)
    # optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        # scheduler.step()

    test(model, device, test_loader)

    if args.save_model:
        save_model(model, model_dir)


if __name__ == "__main__":

    # Training settings
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument("--epochs", type=int, default=10, metavar="N", help="number of epochs to train (default: 10)")
    parser.add_argument(
        "--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)"
    )
    parser.add_argument("--lr", type=float, default=0.1, metavar="LR", help="learning rate (default: 0.1)")
    parser.add_argument("--save-model", action="store_true", default=False, help="for saving the current model")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )

    # These can be passed explicitly when running as a script; when unset,
    # they fall back to the defaults that sagemaker-containers provides.
    parser.add_argument("--train", type=str, default=os.getenv("SM_CHANNEL_TRAIN", None))
    parser.add_argument("--validation", type=str, default=os.getenv("SM_CHANNEL_VALIDATION", None))
    parser.add_argument("--model-dir", type=str, default=os.getenv("SM_MODEL_DIR", None))
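    # Inside a SageMaker container these resolve to the standard paths, e.g.
    # /opt/ml/input/data/train, /opt/ml/input/data/validation, /opt/ml/model.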

    args = parser.parse_args()
    print(args)
    main(args)