Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 5 additions & 48 deletions src/vector/_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ class VectorProtocol:
vector onto azimuthal, longitudinal, and temporal coordinates.
GenericClass (type): The most generic concrete class for this type, for
vectors without momentum-synonyms.
MomentumClass (type): The momentum class for this type, for vectors with
momentum-synonyms.
"""

@property
Expand Down Expand Up @@ -173,6 +175,7 @@ def _wrap_result(
ProjectionClass3D: type[VectorProtocolSpatial]
ProjectionClass4D: type[VectorProtocolLorentz]
GenericClass: type[VectorProtocol]
MomentumClass: type[VectorProtocol]

def to_Vector2D(self) -> VectorProtocolPlanar:
"""Projects this vector/these vectors onto azimuthal coordinates only."""
Expand Down Expand Up @@ -4318,50 +4321,6 @@ def _check_instance(
return any_or_all(isinstance(v, clas) for v in objects)


def _demote_handler_vector(
handler: VectorProtocol,
objects: tuple[VectorProtocol, ...],
vector_class: type[VectorProtocol],
new_vector: VectorProtocol,
) -> VectorProtocol:
"""
Demotes the handler vector to the lowest possible dimension while respecting
the priority of backends.
"""
# if all the objects are not from the same backend
# choose the {X}D object of the backend with highest priority (if it exists)
# or demote the first encountered object of the backend with highest priority to {X}D
backends = [
next(
x.__module__
for x in type(obj).__mro__
if "vector.backends." in x.__module__
)
for obj in objects
]
if len({_handler_priority.index(backend) for backend in backends}) != 1:
new_type = type(new_vector)
flag = 0
# if there is a {X}D object of the backend with highest priority
# make it the new handler
for obj in objects:
if type(obj) == new_type:
handler = obj
flag = 1
break
# else, demote the dimension of the object of the backend with highest priority
if flag == 0:
handler = new_vector
# if all objects are from the same backend
# use the {X}D one as the handler
else:
for obj in objects:
if isinstance(obj, vector_class):
handler = obj

return handler


def _handler_of(*objects: VectorProtocol) -> VectorProtocol:
"""
Determines which vector should wrap the output of a dispatched function.
Expand Down Expand Up @@ -4391,11 +4350,9 @@ def _flavor_of(*objects: VectorProtocol) -> type[VectorProtocol]:
from vector.backends.object import VectorObject

handler: VectorProtocol | None = None
is_momentum = True
is_momentum = any(isinstance(obj, Momentum) for obj in objects)
for obj in objects:
if isinstance(obj, Vector):
if not isinstance(obj, Momentum):
is_momentum = False
if handler is None:
handler = obj
elif isinstance(obj, VectorObject):
Expand All @@ -4405,6 +4362,6 @@ def _flavor_of(*objects: VectorProtocol) -> type[VectorProtocol]:

assert handler is not None
if is_momentum:
return type(handler)
return handler.MomentumClass
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried doing this without creating MomentumClass, but that introduced a lot of edge cases and if conditions in the _flavor_of function. Using MomentumClass makes it clean and uniform for all the backends.

else:
return handler.GenericClass
12 changes: 12 additions & 0 deletions src/vector/backends/awkward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,61 +1527,73 @@ class MomentumRecord4D(MomentumAwkward4D, ak.Record): # type: ignore[misc]
VectorArray2D.ProjectionClass3D = VectorArray3D
VectorArray2D.ProjectionClass4D = VectorArray4D
VectorArray2D.GenericClass = VectorArray2D
VectorArray2D.MomentumClass = MomentumArray2D

VectorRecord2D.ProjectionClass2D = VectorRecord2D
VectorRecord2D.ProjectionClass3D = VectorRecord3D
VectorRecord2D.ProjectionClass4D = VectorRecord4D
VectorRecord2D.GenericClass = VectorRecord2D
VectorRecord2D.MomentumClass = MomentumRecord2D

MomentumArray2D.ProjectionClass2D = MomentumArray2D
MomentumArray2D.ProjectionClass3D = MomentumArray3D
MomentumArray2D.ProjectionClass4D = MomentumArray4D
MomentumArray2D.GenericClass = VectorArray2D
MomentumArray2D.MomentumClass = MomentumArray2D

MomentumRecord2D.ProjectionClass2D = MomentumRecord2D
MomentumRecord2D.ProjectionClass3D = MomentumRecord3D
MomentumRecord2D.ProjectionClass4D = MomentumRecord4D
MomentumRecord2D.GenericClass = VectorRecord2D
MomentumRecord2D.MomentumClass = MomentumRecord2D

VectorArray3D.ProjectionClass2D = VectorArray2D
VectorArray3D.ProjectionClass3D = VectorArray3D
VectorArray3D.ProjectionClass4D = VectorArray4D
VectorArray3D.GenericClass = VectorArray3D
VectorArray3D.MomentumClass = MomentumArray3D

VectorRecord3D.ProjectionClass2D = VectorRecord2D
VectorRecord3D.ProjectionClass3D = VectorRecord3D
VectorRecord3D.ProjectionClass4D = VectorRecord4D
VectorRecord3D.GenericClass = VectorRecord3D
VectorRecord3D.MomentumClass = MomentumRecord3D

MomentumArray3D.ProjectionClass2D = MomentumArray2D
MomentumArray3D.ProjectionClass3D = MomentumArray3D
MomentumArray3D.ProjectionClass4D = MomentumArray4D
MomentumArray3D.GenericClass = VectorArray3D
MomentumArray3D.MomentumClass = MomentumArray3D

MomentumRecord3D.ProjectionClass2D = MomentumRecord2D
MomentumRecord3D.ProjectionClass3D = MomentumRecord3D
MomentumRecord3D.ProjectionClass4D = MomentumRecord4D
MomentumRecord3D.GenericClass = VectorRecord3D
MomentumRecord3D.MomentumClass = MomentumRecord3D

VectorArray4D.ProjectionClass2D = VectorArray2D
VectorArray4D.ProjectionClass3D = VectorArray3D
VectorArray4D.ProjectionClass4D = VectorArray4D
VectorArray4D.GenericClass = VectorArray4D
VectorArray4D.MomentumClass = MomentumArray4D

VectorRecord4D.ProjectionClass2D = VectorRecord2D
VectorRecord4D.ProjectionClass3D = VectorRecord3D
VectorRecord4D.ProjectionClass4D = VectorRecord4D
VectorRecord4D.GenericClass = VectorRecord4D
VectorRecord4D.MomentumClass = MomentumRecord4D

MomentumArray4D.ProjectionClass2D = MomentumArray2D
MomentumArray4D.ProjectionClass3D = MomentumArray3D
MomentumArray4D.ProjectionClass4D = MomentumArray4D
MomentumArray4D.GenericClass = VectorArray4D
MomentumArray4D.MomentumClass = MomentumArray4D

MomentumRecord4D.ProjectionClass2D = MomentumRecord2D
MomentumRecord4D.ProjectionClass3D = MomentumRecord3D
MomentumRecord4D.ProjectionClass4D = MomentumRecord4D
MomentumRecord4D.GenericClass = VectorRecord4D
MomentumRecord4D.MomentumClass = MomentumRecord4D


# implementation of behaviors in Numba ########################################
Expand Down
6 changes: 6 additions & 0 deletions src/vector/backends/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1926,28 +1926,34 @@ def array(*args: typing.Any, **kwargs: typing.Any) -> VectorNumpy:
VectorNumpy2D.ProjectionClass3D = VectorNumpy3D
VectorNumpy2D.ProjectionClass4D = VectorNumpy4D
VectorNumpy2D.GenericClass = VectorNumpy2D
VectorNumpy2D.MomentumClass = MomentumNumpy2D

MomentumNumpy2D.ProjectionClass2D = MomentumNumpy2D
MomentumNumpy2D.ProjectionClass3D = MomentumNumpy3D
MomentumNumpy2D.ProjectionClass4D = MomentumNumpy4D
MomentumNumpy2D.GenericClass = VectorNumpy2D
MomentumNumpy2D.MomentumClass = MomentumNumpy2D

VectorNumpy3D.ProjectionClass2D = VectorNumpy2D
VectorNumpy3D.ProjectionClass3D = VectorNumpy3D
VectorNumpy3D.ProjectionClass4D = VectorNumpy4D
VectorNumpy3D.GenericClass = VectorNumpy3D
VectorNumpy3D.MomentumClass = MomentumNumpy3D

MomentumNumpy3D.ProjectionClass2D = MomentumNumpy2D
MomentumNumpy3D.ProjectionClass3D = MomentumNumpy3D
MomentumNumpy3D.ProjectionClass4D = MomentumNumpy4D
MomentumNumpy3D.GenericClass = VectorNumpy3D
MomentumNumpy3D.MomentumClass = MomentumNumpy3D

VectorNumpy4D.ProjectionClass2D = VectorNumpy2D
VectorNumpy4D.ProjectionClass3D = VectorNumpy3D
VectorNumpy4D.ProjectionClass4D = VectorNumpy4D
VectorNumpy4D.GenericClass = VectorNumpy4D
VectorNumpy4D.MomentumClass = MomentumNumpy4D

MomentumNumpy4D.ProjectionClass2D = MomentumNumpy2D
MomentumNumpy4D.ProjectionClass3D = MomentumNumpy3D
MomentumNumpy4D.ProjectionClass4D = MomentumNumpy4D
MomentumNumpy4D.GenericClass = VectorNumpy4D
MomentumNumpy4D.MomentumClass = MomentumNumpy4D
6 changes: 6 additions & 0 deletions src/vector/backends/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -3204,28 +3204,34 @@ def obj(**coordinates: float) -> VectorObject:
VectorObject2D.ProjectionClass3D = VectorObject3D
VectorObject2D.ProjectionClass4D = VectorObject4D
VectorObject2D.GenericClass = VectorObject2D
VectorObject2D.MomentumClass = MomentumObject2D

MomentumObject2D.ProjectionClass2D = MomentumObject2D
MomentumObject2D.ProjectionClass3D = MomentumObject3D
MomentumObject2D.ProjectionClass4D = MomentumObject4D
MomentumObject2D.GenericClass = VectorObject2D
MomentumObject2D.MomentumClass = MomentumObject2D

VectorObject3D.ProjectionClass2D = VectorObject2D
VectorObject3D.ProjectionClass3D = VectorObject3D
VectorObject3D.ProjectionClass4D = VectorObject4D
VectorObject3D.GenericClass = VectorObject3D
VectorObject3D.MomentumClass = MomentumObject3D

MomentumObject3D.ProjectionClass2D = MomentumObject2D
MomentumObject3D.ProjectionClass3D = MomentumObject3D
MomentumObject3D.ProjectionClass4D = MomentumObject4D
MomentumObject3D.GenericClass = VectorObject3D
MomentumObject3D.MomentumClass = MomentumObject3D

VectorObject4D.ProjectionClass2D = VectorObject2D
VectorObject4D.ProjectionClass3D = VectorObject3D
VectorObject4D.ProjectionClass4D = VectorObject4D
VectorObject4D.GenericClass = VectorObject4D
VectorObject4D.MomentumClass = MomentumObject4D

MomentumObject4D.ProjectionClass2D = MomentumObject2D
MomentumObject4D.ProjectionClass3D = MomentumObject3D
MomentumObject4D.ProjectionClass4D = MomentumObject4D
MomentumObject4D.GenericClass = VectorObject4D
MomentumObject4D.MomentumClass = MomentumObject4D
78 changes: 50 additions & 28 deletions tests/backends/test_awkward.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@

import vector
from vector import VectorObject2D
from vector.backends.awkward import (
MomentumAwkward2D,
MomentumAwkward3D,
MomentumAwkward4D,
)

ak = pytest.importorskip("awkward")

Expand Down Expand Up @@ -788,34 +793,6 @@ def test_like():
assert ak.all(v3 + v2.like(v3) == v3_v2)
assert ak.all(v2.like(v3) + v3 == v3_v2)

v2 = vector.zip(
{
"x": [10.0, 20.0, 30.0],
"y": [-10.0, 20.0, 30.0],
"z": [5.0, 1.0, 1.0],
},
)

# momentum + generic = generic
# 2D + 3D.like(2D) = 2D
assert ak.all(v1 + v2.like(v1) == v1_v2)
assert ak.all(v2.like(v1) + v1 == v1_v2)
# 2D + 4D.like(2D) = 2D
assert ak.all(v1 + v3.like(v1) == v1_v2)
assert ak.all(v3.like(v1) + v1 == v1_v2)
# 3D + 2D.like(3D) = 3D
assert ak.all(v2 + v1.like(v2) == v2_v1)
assert ak.all(v1.like(v2) + v2 == v2_v1)
# 3D + 4D.like(3D) = 3D
assert ak.all(v2 + v3.like(v2) == v2_v3)
assert ak.all(v3.like(v2) + v2 == v2_v3)
# 4D + 2D.like(4D) = 4D
assert ak.all(v3 + v1.like(v3) == v1_v3)
assert ak.all(v1.like(v3) + v3 == v1_v3)
# 4D + 3D.like(4D) = 4D
assert ak.all(v3 + v2.like(v3) == v3_v2)
assert ak.all(v2.like(v3) + v3 == v3_v2)


def test_handler_of():
numpy_vec = vector.array(
Expand Down Expand Up @@ -873,3 +850,48 @@ def test_momentum_coordinate_transforms():
assert hasattr(transformed_object, t1[2:])
assert hasattr(transformed_object, t2)
assert hasattr(transformed_object, t3)


def test_momentum_preservation():
v1 = vector.zip(
{
"px": [10.0, 20.0, 30.0],
"py": [-10.0, 20.0, 30.0],
},
)
v2 = vector.zip(
{
"x": [10.0, 20.0, 30.0],
"y": [-10.0, 20.0, 30.0],
"z": [5.0, 1.0, 1.0],
},
)

v3 = vector.zip(
{
"px": [10.0, 20.0, 30.0],
"py": [-10.0, 20.0, 30.0],
"pz": [5.0, 1.0, 1.0],
"t": [16.0, 31.0, 46.0],
},
)

# momentum + generic = momentum
# 2D + 3D.like(2D) = 2D
assert isinstance(v1 + v2.like(v1), MomentumAwkward2D)
assert isinstance(v2.like(v1) + v1, MomentumAwkward2D)
# 2D + 4D.like(2D) = 2D
assert isinstance(v1 + v3.like(v1), MomentumAwkward2D)
assert isinstance(v3.like(v1) + v1, MomentumAwkward2D)
# 3D + 2D.like(3D) = 3D
assert isinstance(v2 + v1.like(v2), MomentumAwkward3D)
assert isinstance(v1.like(v2) + v2, MomentumAwkward3D)
# 3D + 4D.like(3D) = 3D
assert isinstance(v2 + v3.like(v2), MomentumAwkward3D)
assert isinstance(v3.like(v2) + v2, MomentumAwkward3D)
# 4D + 2D.like(4D) = 4D
assert isinstance(v3 + v1.like(v3), MomentumAwkward4D)
assert isinstance(v1.like(v3) + v3, MomentumAwkward4D)
# 4D + 3D.like(4D) = 4D
assert isinstance(v3 + v2.like(v3), MomentumAwkward4D)
assert isinstance(v2.like(v3) + v3, MomentumAwkward4D)
Loading