|
5 | 5 | from uuid import uuid4 |
6 | 6 |
|
7 | 7 | import numpy as np |
| 8 | +import pandas as pd |
8 | 9 |
|
9 | 10 | from data_manager import file_processor |
| 11 | +from returns_quantization import add_returns_in_place |
10 | 12 | from utils import * |
11 | 13 |
|
# Show arrays and data frames in full (no "..." summarization) while debugging.
# NumPy rejects NaN as a threshold (ValueError); an infinite threshold is
# accepted and means "never summarize".
np.set_printoptions(threshold=np.inf)
# NOTE: the former 'display.height' option no longer exists in pandas and
# setting it raises OptionError; 'display.max_rows' controls the row limit.
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
| 19 | + |
| 20 | + |
def generate_quantiles(data_folder, bitcoin_file):
    """Generate the dataset with the quantized-return label as the class name.

    Delegates to ``generate_cnn_dataset`` and supplies a labelling callback
    that reads the 'close_price_returns_labels' column of the row located
    right after the sampled slice.

    :param data_folder: forwarded unchanged to ``generate_cnn_dataset``.
    :param bitcoin_file: forwarded unchanged to ``generate_cnn_dataset``.
    :return: whatever ``generate_cnn_dataset`` returns.
    """

    def get_label(btc_df, btc_slice, i, slice_size):
        # btc_slice is part of the callback signature but is not needed here:
        # the label lives on the row that follows the slice.
        following = i + slice_size
        next_row = btc_df[following:following + 1]
        return str(next_row['close_price_returns_labels'].values[0])

    return generate_cnn_dataset(data_folder, bitcoin_file, get_label)
| 27 | + |
| 28 | + |
def generate_up_down(data_folder, bitcoin_file):
    """Generate the dataset with 'UP' / 'DOWN' price direction as the class name.

    Delegates to ``generate_cnn_dataset`` and supplies a labelling callback
    that compares the closing price of the last row of the sampled slice with
    the closing price of the row that immediately follows the slice.

    :param data_folder: forwarded unchanged to ``generate_cnn_dataset``.
    :param bitcoin_file: forwarded unchanged to ``generate_cnn_dataset``.
    :return: whatever ``generate_cnn_dataset`` returns.
    """

    def get_price_direction(btc_df, btc_slice, i, slice_size):
        # Use the LAST row of the slice. The previous `[-2:-1]` selected the
        # second-to-last row, so the label compared prices two steps apart.
        last_price = btc_slice[-1:]['price_close'].values[0]
        # Row right after the slice, i.e. the value to predict.
        next_price = btc_df[i + slice_size:i + slice_size + 1]['price_close'].values[0]
        # A flat price (equal values) is classified as 'DOWN'.
        return 'UP' if last_price < next_price else 'DOWN'

    return generate_cnn_dataset(data_folder, bitcoin_file, get_price_direction)
| 40 | + |
| 41 | + |
def generate_cnn_dataset(data_folder, bitcoin_file, get_class_name,
                         slice_size=40, test_every_steps=10, num_samples=int(1e6)):
    """Sample random slices of the price data and save each one as a PNG file.

    The data is loaded with ``file_processor`` and augmented with
    ``add_returns_in_place``. For every sample a random window of
    ``slice_size`` consecutive rows is drawn, labelled through
    ``get_class_name`` and written with ``save_to_file`` to
    ``<data_folder>/<train|test>/<class_name>/<uuid4>.png``.

    :param data_folder: output root; it is deleted before generation starts.
    :param bitcoin_file: path handed to ``file_processor``.
    :param get_class_name: callable ``(btc_df, btc_slice, i, slice_size)``
        returning the class (sub-folder) name of the sample.
    :param slice_size: number of rows per sampled window.
    :param test_every_steps: every n-th sample goes to 'test', others to 'train'.
    :param num_samples: number of samples to generate.
    :raises ValueError: if the data is not longer than ``slice_size`` or if a
        sampled slice contains NaN values.
    """
    btc_df = file_processor(bitcoin_file)
    btc_df, levels = add_returns_in_place(btc_df)

    # Sanity check: print the observed frequency of every label.
    print('-' * 80)
    print('Those values should be roughly equal to 1/len(levels):')
    labels = btc_df['close_price_returns_labels']
    for label in range(len(levels)):
        print(label, np.mean((labels == label).values))
    print(levels)
    print('-' * 80)

    # Valid start offsets are 0..n-1 so that row i + slice_size (the row read
    # by the labelling callbacks) always exists.
    n = len(btc_df) - slice_size
    if n <= 0:
        # Fail before wiping the output folder.
        raise ValueError('Not enough data: {} rows for a slice size of {}.'.format(len(btc_df), slice_size))

    shutil.rmtree(data_folder, ignore_errors=True)
    for epoch in range(num_samples):
        st = time()

        i = np.random.choice(n)
        btc_slice = btc_df[i:i + slice_size]

        if btc_slice.isnull().values.any():
            # sometimes prices are discontinuous and nothing happened in one 5min bucket.
            # in that case, we consider this slice as wrong and we raise an exception.
            # it's likely to happen at the beginning of the data set where the volumes are low.
            raise ValueError('NaN values detected. Please remove them.')

        class_name = get_class_name(btc_df, btc_slice, i, slice_size)
        subset = 'test' if epoch % test_every_steps == 0 else 'train'
        save_dir = os.path.join(data_folder, subset, class_name)
        mkdir_p(save_dir)
        filename = os.path.join(save_dir, str(uuid4()) + '.png')
        save_to_file(btc_slice, filename=filename)
        print('epoch = {0}, time = {1:.3f}, filename = {2}'.format(str(epoch).zfill(8), time() - st, filename))
47 | 78 |
|
48 | 79 |
|
def main():
    """Command-line entry point.

    Expects three arguments: the output data folder, the bitcoin market data
    CSV path and an integer flag USE_QUANTILES (non-zero selects
    ``generate_quantiles``, zero selects ``generate_up_down``).

    Exits with the usage message when the argument count is wrong or the flag
    is not an integer.
    """
    args = sys.argv
    usage = 'Usage: python3 {} DATA_FOLDER_TO_STORE_GENERATED_DATASET ' \
            'BITCOIN_MARKET_DATA_CSV_PATH USE_QUANTILES'.format(args[0])
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(args) != 4:
        sys.exit(usage)

    data_folder = args[1]
    bitcoin_file = args[2]
    try:
        use_quantiles = int(args[3])
    except ValueError:
        sys.exit(usage)

    data_gen_func = generate_quantiles if use_quantiles else generate_up_down
    print('Using: {}'.format(data_gen_func))
    data_gen_func(data_folder, bitcoin_file)
56 | 91 |
|
57 | 92 |
|
58 | 93 | if __name__ == '__main__': |
|
0 commit comments