Skip to content

Commit ba306ac

Browse files
authored
Merge pull request #14 from DanielAvdar/init
Init
2 parents 2c74081 + dd5bdf6 commit ba306ac

15 files changed

+123
-35
lines changed

docs/source/api.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
API Reference
2-
============
2+
=============
33

44
This section provides detailed information about the prob-spaces API.
55

66
Probability Spaces
7-
-----------------
7+
------------------
88

99
.. automodule:: prob_spaces
1010
:members:
@@ -20,7 +20,7 @@ Converter
2020
:show-inheritance:
2121

2222
Distributions
23-
------------
23+
-------------
2424

2525
.. automodule:: prob_spaces.dists
2626
:members:

docs/source/index.rst

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,40 @@ environments.
1919
:target: https://opensource.org/licenses/MIT
2020
:alt: License
2121

22+
.. image:: https://img.shields.io/badge/ubuntu-blue?logo=ubuntu
23+
:alt: Ubuntu
24+
25+
.. image:: https://img.shields.io/badge/ubuntu-blue?logo=windows
26+
:alt: Windows
27+
28+
.. image:: https://img.shields.io/badge/ubuntu-blue?logo=apple
29+
:alt: MacOS
30+
31+
32+
2233
.. toctree::
23-
:maxdepth: 2
24-
:caption: Contents:
25-
26-
introduction
27-
installation
28-
api
29-
modules/discrete
30-
modules/multi_discrete
31-
modules/box
32-
modules/dict
33-
modules/converter
34+
:maxdepth: 1
35+
:caption: Introduction:
36+
37+
./introduction
38+
./installation
39+
40+
.. toctree::
41+
:maxdepth: 1
42+
:caption: Usage:
43+
44+
./modules/discrete
45+
./modules/multi_discrete
46+
./modules/box
47+
./modules/dict
48+
./modules/converter
49+
50+
.. toctree::
51+
:maxdepth: 1
52+
:caption: API Reference:
53+
54+
./api
55+
3456

3557
Indices and tables
3658
==================

docs/source/installation.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
.. _installation:
2+
13
Installation
24
===========
35

docs/source/introduction.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
.. _introduction:
2+
13
Introduction
24
===========
35

docs/source/modules/box.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
.. _box:
2+
13
Box Space
2-
========
4+
=========
35

46
The ``BoxDist`` class extends the Gymnasium Box space to create continuous probability distributions.
57

@@ -14,7 +16,7 @@ API Reference
1416
-------------
1517

1618
.. autoclass:: prob_spaces.box.BoxDist
17-
:members:
19+
:members: __call__
1820
:undoc-members:
1921
:show-inheritance:
2022

docs/source/modules/converter.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
.. _converter:
2+
13
Space Converter
24
===============
35

docs/source/modules/dict.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
.. _dict:
2+
13
Dict Space
2-
=========
4+
==========
35

46
The ``DictDist`` class extends the Gymnasium Dict space to create nested distributions.
57

@@ -13,7 +15,7 @@ API Reference
1315
------------
1416

1517
.. autoclass:: prob_spaces.dict.DictDist
16-
:members:
18+
:members: __call__
1719
:undoc-members:
1820
:show-inheritance:
1921

docs/source/modules/discrete.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
.. _discrete:
2+
13
Discrete Space
2-
=============
4+
==============
35

46
The ``DiscreteDist`` class extends the Gymnasium Discrete space to create categorical distributions.
57

@@ -13,7 +15,7 @@ API Reference
1315
------------
1416

1517
.. autoclass:: prob_spaces.discrete.DiscreteDist
16-
:members:
18+
:members: __call__
1719
:undoc-members:
1820
:show-inheritance:
1921

docs/source/modules/multi_discrete.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
.. _multi_discrete:
2+
13
MultiDiscrete Space
2-
==================
4+
===================
35

46
The ``MultiDiscreteDist`` class extends the Gymnasium MultiDiscrete space to create categorical distributions
57
for multiple discrete variables.
@@ -14,7 +16,7 @@ API Reference
1416
------------
1517

1618
.. autoclass:: prob_spaces.multi_discrete.MultiDiscreteDist
17-
:members:
19+
:members: __call__
1820
:undoc-members:
1921
:show-inheritance:
2022

prob_spaces/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Probability distribution classes for various Gymnasium spaces."""
2+
3+
from .converter import convert_to_prob_space
4+
from .dict import DictDist
5+
from .discrete import DiscreteDist
6+
from .multi_discrete import MultiDiscreteDist
7+
8+
__all__ = [
9+
"MultiDiscreteDist",
10+
"DiscreteDist",
11+
"DictDist",
12+
"convert_to_prob_space",
13+
]

prob_spaces/box.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,17 @@ def transforms(self, device: th.device) -> list:
3333
return transforms
3434

3535
def __call__(self, loc: th.Tensor, scale: th.Tensor) -> th.distributions.Distribution:
36+
"""
37+
Generates a transformed probability distribution based on the input location and scale
38+
parameters. The method constructs a base distribution, applies a sequence of
39+
transformations to it, and returns the resulting transformed distribution. This
40+
allows for creating flexible and expressive probability distributions.
41+
42+
:param loc: A tensor specifying the location parameters for the base distribution.
43+
:param scale: A tensor specifying the scale parameters for the base distribution.
44+
:return: A transformed distribution object derived from the specified base distribution
45+
and transformations.
46+
"""
3647
dist = self.base_dist(loc, scale, validate_args=True) # type: ignore
3748
transforms = self.transforms(loc.device)
3849
transformed_dist = TransformedDistribution(dist, transforms, validate_args=True)

prob_spaces/discrete.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,17 @@
66

77
class DiscreteDist(spaces.Discrete):
88
def __call__(self, prob: th.Tensor, mask: th.Tensor = None) -> MaskedCategorical:
9+
"""
10+
Compute and return a masked categorical distribution based on the given probability
11+
tensor and an optional mask. The distribution incorporates specific probabilities
12+
and constraints defined by the provided input.
13+
14+
:param prob: A tensor representing the probabilities for each category.
15+
:param mask: A tensor specifying a mask to limit the valid categories.
16+
Defaults to a tensor of ones if not provided.
17+
:return: A MaskedCategorical distribution constructed with given probabilities,
18+
mask, and starting values.
19+
"""
920
probs = prob.reshape(self.n) # type: ignore
1021
start = self.start
1122
mask = mask if mask is not None else th.ones_like(probs, dtype=th.bool, device=probs.device)

prob_spaces/multi_discrete.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,23 @@ def _internal_mask(self) -> NDArray[np.bool_]:
3737
return mask
3838

3939
def __call__(self, prob: th.Tensor, mask: th.Tensor = None) -> MaskedCategorical:
40+
"""
41+
Applies a transformation to the input probability tensor and optional mask, creating
42+
a `MaskedCategorical` distribution. The method reshapes the input probabilities
43+
to match the specified `nvec` dimensions, applies an optional mask for masking
44+
specific probabilities, and combines these with an internal mask. The result
45+
is used to create a `MaskedCategorical` distribution.
46+
47+
:param prob: A tensor containing probabilities to be reshaped and used in
48+
constructing the distribution.
49+
:type prob: th.Tensor
50+
:param mask: An optional boolean tensor for masking specific probabilities
51+
before creating the distribution. Defaults to None.
52+
:type mask: th.Tensor, optional
53+
:return: A `MaskedCategorical` distribution object created with reshaped
54+
probabilities and combined masking information.
55+
:rtype: MaskedCategorical
56+
"""
4057
probs = prob.reshape(*self.nvec.shape, self.prob_last_dim)
4158
start = self.start
4259
mask = mask if mask is not None else th.ones_like(probs, dtype=th.bool, device=probs.device)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ authors = [
77
]
88
license = { text = "MIT" }
99
readme = "README.md"
10-
requires-python = ">=3.10,<4"
10+
requires-python = ">=3.10"
1111

1212
keywords = [
1313
"python"

uv.lock

Lines changed: 12 additions & 12 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)