Skip to content

Commit ed23e44

Browse files
committed
commit
1 parent 7fc8cfe commit ed23e44

File tree

17 files changed

+535
-0
lines changed

17 files changed

+535
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#pragma once
2+
3+
#include <string>
4+
#include <vector>
5+
6+
#include <torch/torch.h>
7+
8+
#include "cpprl/algorithms/algorithm.h"
9+
10+
namespace cpprl
11+
{
12+
// Forward declarations for the types this header only uses by reference.
// The second name must be spelled RolloutStorage so that it refers to the
// same type that update() takes as its parameter.
class Policy;
class RolloutStorage;
14+
15+
class A2C : public Algorithm
16+
{
17+
private:
18+
Policy &policy;
19+
float value_loss_coef, entropy_coef, max_grad_norm;
20+
std::unique_ptr<torch::optim::Optimizer> optimizer;
21+
22+
public:
23+
A2C(Policy &policy,
24+
float value_loss_coef,
25+
float entropy_coef,
26+
float learning_rate,
27+
float epsilon = 1e-8,
28+
float alpha = 0.99,
29+
float max_grad_norm = 0.5);
30+
31+
std::vector<UpdateDatum> update(RolloutStorage &rollouts);
32+
};
33+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#pragma once
2+
3+
#include <string>
4+
#include <vector>
5+
6+
#include "cpprl/storage.h"
7+
8+
namespace cpprl
9+
{
10+
// A named scalar; Algorithm::update() returns a vector of these.
struct UpdateDatum
{
    std::string name;
    // Default member initialiser so a default-constructed datum never exposes
    // an indeterminate float. The struct remains an aggregate from C++14 on,
    // so `UpdateDatum{"name", 1.0f}` keeps working.
    float value = 0.0f;
};
15+
16+
class Algorithm
17+
{
18+
public:
19+
virtual ~Algorithm() = 0;
20+
21+
virtual std::vector<UpdateDatum> update(RolloutStorage &rollouts) = 0;
22+
};
23+
24+
inline Algorithm::~Algorithm() {}
25+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#pragma once
2+
3+
#include <string>
4+
#include <vector>
5+
6+
#include <torch/torch.h>
7+
8+
#include "cpprl/algorithms/algorithm.h"
9+
10+
namespace cpprl
11+
{
12+
// Forward declarations for the types this header only uses by reference.
// The second name must be spelled RolloutStorage so that it refers to the
// same type that update() takes as its parameter.
class Policy;
class RolloutStorage;
14+
15+
class PPO : public Algorithm
16+
{
17+
private:
18+
Policy &policy;
19+
float clip_param, value_loss_coef, entropy_coef, max_grad_norm;
20+
int num_epoch, num_mini_batch;
21+
std::unique_ptr<torch::optim::Optimizer> optimizer;
22+
23+
public:
24+
PPO(Policy &policy,
25+
float clip_param,
26+
int num_epoch,
27+
int num_mini_batch,
28+
float value_loss_coef,
29+
float entropy_coef,
30+
float learning_rate,
31+
float epsilon = 1e-8,
32+
float max_grad_norm = 0.5);
33+
34+
std::vector<UpdateDatum> update(RolloutStorage &rollouts);
35+
};
36+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#include "cpprl/algorithms/a2c.h"
2+
#include "cpprl/algorithms/algorithm.h"
3+
#include "cpprl/algorithms/ppo.h"
4+
#include "cpprl/distributions/distribution.h"
5+
#include "cpprl/distributions/categorical.h"
6+
#include "cpprl/generators/generator.h"
7+
#include "cpprl/generators/feed_forward_generator.h"
8+
#include "cpprl/model/cnn_base.h"
9+
#include "cpprl/model/mlp_base.h"
10+
#include "cpprl/model/model_utils.h"
11+
#include "cpprl/model/nn_base.h"
12+
#include "cpprl/model/output_layers.h"
13+
#include "cpprl/model/policy.h"
14+
#include "cpprl/spaces.h"
15+
#include "cpprl/storage.h"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#pragma once
2+
3+
#include <c10/util/ArrayRef.h>
4+
#include <torch/torch.h>
5+
6+
#include "cpprl/distributions/distribution.h"
7+
8+
namespace cpprl
9+
{
10+
class Categorical : public Distribution
11+
{
12+
private:
13+
torch::Tensor probs;
14+
torch::Tensor logits;
15+
std::vector<long> batch_shape;
16+
std::vector<long> event_shape;
17+
torch::Tensor param;
18+
int num_events;
19+
20+
std::vector<long> extended_shape(c10::ArrayRef<int64_t> sample_shape);
21+
22+
public:
23+
Categorical(const torch::Tensor *probs, const torch::Tensor *logits);
24+
25+
torch::Tensor entropy();
26+
torch::Tensor log_prob(torch::Tensor value);
27+
torch::Tensor sample(c10::ArrayRef<int64_t> sample_shape = {});
28+
29+
inline torch::Tensor get_logits() { return logits; }
30+
inline torch::Tensor get_probs() { return probs; }
31+
};
32+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#pragma once
2+
3+
#include <torch/torch.h>
4+
5+
namespace cpprl
6+
{
7+
class Distribution
8+
{
9+
public:
10+
virtual ~Distribution() = 0;
11+
12+
virtual torch::Tensor entropy() = 0;
13+
virtual torch::Tensor get_logits() = 0;
14+
virtual torch::Tensor get_probs() = 0;
15+
virtual torch::Tensor log_prob(torch::Tensor value) = 0;
16+
virtual torch::Tensor sample(c10::ArrayRef<int64_t> sample_shape = {}) = 0;
17+
};
18+
19+
inline Distribution::~Distribution() {}
20+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#pragma once
2+
3+
#include <torch/torch.h>
4+
5+
#include "cpprl/generators/generator.h"
6+
7+
namespace cpprl
8+
{
9+
class FeedForwardGenerator : public Generator
10+
{
11+
private:
12+
torch::Tensor observations, hidden_states, actions, value_predictions,
13+
returns, masks, action_log_probs, advantages, indices;
14+
int index;
15+
16+
public:
17+
FeedForwardGenerator(int mini_batch_size,
18+
torch::Tensor observations,
19+
torch::Tensor hidden_states,
20+
torch::Tensor actions,
21+
torch::Tensor value_predictions,
22+
torch::Tensor returns,
23+
torch::Tensor masks,
24+
torch::Tensor action_log_probs,
25+
torch::Tensor advantages);
26+
27+
virtual bool done() const;
28+
virtual MiniBatch next();
29+
};
30+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#pragma once
2+
3+
#include <vector>
4+
5+
#include <torch/torch.h>
6+
7+
namespace cpprl
8+
{
9+
struct MiniBatch
10+
{
11+
torch::Tensor observations, hidden_states, actions, value_predictions,
12+
returns, masks, action_log_probs, advantages;
13+
14+
MiniBatch() {}
15+
MiniBatch(torch::Tensor observations,
16+
torch::Tensor hidden_states,
17+
torch::Tensor actions,
18+
torch::Tensor value_predictions,
19+
torch::Tensor returns,
20+
torch::Tensor masks,
21+
torch::Tensor action_log_probs,
22+
torch::Tensor advantages)
23+
: observations(observations),
24+
hidden_states(hidden_states),
25+
actions(actions),
26+
value_predictions(value_predictions),
27+
returns(returns),
28+
masks(masks),
29+
action_log_probs(action_log_probs),
30+
advantages(advantages) {}
31+
};
32+
33+
class Generator
34+
{
35+
public:
36+
virtual ~Generator() = 0;
37+
38+
virtual bool done() const = 0;
39+
virtual MiniBatch next() = 0;
40+
};
41+
42+
inline Generator::~Generator() {}
43+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#pragma once
2+
3+
#include <torch/torch.h>
4+
5+
#include "cpprl/generators/generator.h"
6+
7+
namespace cpprl
8+
{
9+
class RecurrentGenerator : public Generator
10+
{
11+
private:
12+
torch::Tensor observations, hidden_states, actions, value_predictions,
13+
returns, masks, action_log_probs, advantages, indices;
14+
int index, num_envs_per_batch;
15+
16+
public:
17+
RecurrentGenerator(int num_processes,
18+
int num_mini_batch,
19+
torch::Tensor observations,
20+
torch::Tensor hidden_states,
21+
torch::Tensor actions,
22+
torch::Tensor value_predictions,
23+
torch::Tensor returns,
24+
torch::Tensor masks,
25+
torch::Tensor action_log_probs,
26+
torch::Tensor advantages);
27+
28+
virtual bool done() const;
29+
virtual MiniBatch next();
30+
};
31+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#pragma once
2+
3+
#include <vector>
4+
5+
#include <torch/torch.h>
6+
7+
#include "cpprl/model/nn_base.h"
8+
9+
using namespace torch;
10+
11+
namespace cpprl
12+
{
13+
class CnnBase : public NNBase
14+
{
15+
private:
16+
nn::Sequential main;
17+
nn::Sequential critic_linear;
18+
19+
public:
20+
CnnBase(unsigned int num_inputs,
21+
bool recurrent = false,
22+
unsigned int hidden_size = 512);
23+
24+
std::vector<torch::Tensor> forward(torch::Tensor inputs,
25+
torch::Tensor hxs,
26+
torch::Tensor masks);
27+
};
28+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#pragma once
2+
3+
#include <vector>
4+
5+
#include <torch/torch.h>
6+
7+
#include "cpprl/model/nn_base.h"
8+
9+
using namespace torch;
10+
11+
namespace cpprl
12+
{
13+
class MlpBase : public NNBase
14+
{
15+
private:
16+
nn::Sequential actor;
17+
nn::Sequential critic;
18+
nn::Linear critic_linear;
19+
20+
public:
21+
MlpBase(unsigned int num_inputs,
22+
bool recurrent = false,
23+
unsigned int hidden_size = 64);
24+
25+
std::vector<torch::Tensor> forward(torch::Tensor inputs,
26+
torch::Tensor hxs,
27+
torch::Tensor masks);
28+
};
29+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#pragma once
2+
3+
#include <vector>
4+
5+
#include <torch/torch.h>
6+
7+
using namespace torch;
8+
9+
namespace cpprl
10+
{
11+
struct FlattenImpl : nn::Module
12+
{
13+
torch::Tensor forward(torch::Tensor x);
14+
};
15+
TORCH_MODULE(Flatten);
16+
17+
// Declaration only; the definition lives in the matching source file.
// parameters  - name -> tensor dictionary, taken by value.
// weight_gain, bias_gain - NOTE(review): presumably the gains applied when
//               initialising weight and bias tensors; confirm in the source.
void init_weights(
    torch::OrderedDict<std::string, torch::Tensor> parameters,
    double weight_gain,
    double bias_gain);
20+
}

0 commit comments

Comments
 (0)