Skip to content

Commit ba885a1

Browse files
stegbensoumith
authored andcommitted
expose bitwise operators from C/CUDA (pytorch#1556)
* fix issue pytorch#1549, expose bitwise and * expose C bitwise or of Tensor * expose C bitwise xor of Tensor * use built-in method for inplace and, or, xor * expose C bitwise lshift(ilshift) and rshift(irshift) of Tensor
1 parent ce1a0eb commit ba885a1

File tree

2 files changed

+180
-36
lines changed

2 files changed

+180
-36
lines changed

torch/csrc/generic/methods/Tensor.cwrap

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -824,3 +824,183 @@ PyObject * THPTensor_(copy_)(PyObject *self, PyObject *args, PyObject *kwargs)
824824
return THPCopyMethod(THTensor_(copy_functions), self, args, kwargs);
825825
END_HANDLE_TH_ERRORS
826826
}
827+
828+
[[
829+
name: __and__
830+
with_stateless: True
831+
return: argument 0
832+
options:
833+
- cname: bitand
834+
arguments:
835+
- arg: THTensor* result
836+
output: True
837+
- THTensor* self
838+
- real value
839+
- cname: cbitand
840+
arguments:
841+
- arg: THTensor* result
842+
output: True
843+
- THTensor* self
844+
- THTensor* other
845+
]]
846+
847+
[[
848+
name: __iand__
849+
with_stateless: True
850+
return: argument 0
851+
options:
852+
- cname: bitand
853+
arguments:
854+
- THTensor* self
855+
- THTensor* self
856+
- real value
857+
- cname: cbitand
858+
arguments:
859+
- THTensor* self
860+
- THTensor* self
861+
- THTensor* other
862+
]]
863+
864+
[[
865+
name: __or__
866+
with_stateless: True
867+
return: argument 0
868+
options:
869+
- cname: bitor
870+
arguments:
871+
- arg: THTensor* result
872+
output: True
873+
- THTensor* self
874+
- real value
875+
- cname: cbitor
876+
arguments:
877+
- arg: THTensor* result
878+
output: True
879+
- THTensor* self
880+
- THTensor* other
881+
]]
882+
883+
[[
884+
name: __ior__
885+
with_stateless: True
886+
return: argument 0
887+
options:
888+
- cname: bitor
889+
arguments:
890+
- THTensor* self
891+
- THTensor* self
892+
- real value
893+
- cname: cbitor
894+
arguments:
895+
- THTensor* self
896+
- THTensor* self
897+
- THTensor* other
898+
]]
899+
900+
[[
901+
name: __xor__
902+
with_stateless: True
903+
return: argument 0
904+
options:
905+
- cname: bitxor
906+
arguments:
907+
- arg: THTensor* result
908+
output: True
909+
- THTensor* self
910+
- real value
911+
- cname: cbitxor
912+
arguments:
913+
- arg: THTensor* result
914+
output: True
915+
- THTensor* self
916+
- THTensor* other
917+
]]
918+
919+
[[
920+
name: __ixor__
921+
with_stateless: True
922+
return: argument 0
923+
options:
924+
- cname: bitxor
925+
arguments:
926+
- THTensor* self
927+
- THTensor* self
928+
- real value
929+
- cname: cbitxor
930+
arguments:
931+
- THTensor* self
932+
- THTensor* self
933+
- THTensor* other
934+
]]
935+
936+
[[
937+
name: __lshift__
938+
with_stateless: True
939+
return: argument 0
940+
options:
941+
- cname: lshift
942+
arguments:
943+
- arg: THTensor* result
944+
output: True
945+
- THTensor* self
946+
- real value
947+
- cname: clshift
948+
arguments:
949+
- arg: THTensor* result
950+
output: True
951+
- THTensor* self
952+
- THTensor* other
953+
]]
954+
955+
[[
956+
name: __ilshift__
957+
with_stateless: True
958+
return: argument 0
959+
options:
960+
- cname: lshift
961+
arguments:
962+
- THTensor* self
963+
- THTensor* self
964+
- real value
965+
- cname: clshift
966+
arguments:
967+
- THTensor* self
968+
- THTensor* self
969+
- THTensor* other
970+
]]
971+
972+
[[
973+
name: __rshift__
974+
with_stateless: True
975+
return: argument 0
976+
options:
977+
- cname: rshift
978+
arguments:
979+
- arg: THTensor* result
980+
output: True
981+
- THTensor* self
982+
- real value
983+
- cname: crshift
984+
arguments:
985+
- arg: THTensor* result
986+
output: True
987+
- THTensor* self
988+
- THTensor* other
989+
]]
990+
991+
[[
992+
name: __irshift__
993+
with_stateless: True
994+
return: argument 0
995+
options:
996+
- cname: rshift
997+
arguments:
998+
- THTensor* self
999+
- THTensor* self
1000+
- real value
1001+
- cname: crshift
1002+
arguments:
1003+
- THTensor* self
1004+
- THTensor* self
1005+
- THTensor* other
1006+
]]

torch/tensor.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -345,42 +345,6 @@ def __ge__(self, other):
345345
return self.ge(other)
346346

347347
# TODO: add native add or and xor in the libs
348-
def __and__(self, other):
349-
if (type(self).__name__ != 'ByteTensor' or
350-
type(other).__name__ != 'ByteTensor'):
351-
raise RuntimeError('logical operations are supported on ByteTensors only')
352-
return (self + other).eq(2)
353-
354-
def __or__(self, other):
355-
if (type(self).__name__ != 'ByteTensor' or
356-
type(other).__name__ != 'ByteTensor'):
357-
raise RuntimeError('logical operations are supported on ByteTensors only')
358-
return (self + other).gt(0)
359-
360-
def __xor__(self, other):
361-
if (type(self).__name__ != 'ByteTensor' or
362-
type(other).__name__ != 'ByteTensor'):
363-
raise RuntimeError('logical operations are supported on ByteTensors only')
364-
return (self + other).eq(1)
365-
366-
def __iand__(self, other):
367-
if (type(self).__name__ != 'ByteTensor' or
368-
type(other).__name__ != 'ByteTensor'):
369-
raise RuntimeError('logical operations are supported on ByteTensors only')
370-
return self.mul_(other)
371-
372-
def __ior__(self, other):
373-
if (type(self).__name__ != 'ByteTensor' or
374-
type(other).__name__ != 'ByteTensor'):
375-
raise RuntimeError('logical operations are supported on ByteTensors only')
376-
return self.copy_((self + other).gt(0))
377-
378-
def __ixor__(self, other):
379-
if (type(self).__name__ != 'ByteTensor' or
380-
type(other).__name__ != 'ByteTensor'):
381-
raise RuntimeError('logical operations are supported on ByteTensors only')
382-
return self.copy_((self + other).eq(1))
383-
384348
def __invert__(self):
385349
if type(self).__name__ != 'ByteTensor':
386350
raise RuntimeError('logical operations are supported on ByteTensors only')

0 commit comments

Comments
 (0)