Skip to content

Commit e9f9dd7

Browse files
authored
Merge branch 'master' into fix-3.8-warning
2 parents 83c9dda + 486119e commit e9f9dd7

File tree

2 files changed

+53
-12
lines changed

2 files changed

+53
-12
lines changed

fn/func.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from functools import partial, update_wrapper, wraps
2-
from inspect import getargspec
2+
from sys import version_info
3+
4+
5+
_has_type_hint_support = version_info[:2] >= (3, 5)
36

47

58
def identity(arg):
@@ -69,6 +72,17 @@ def curried(func):
6972
>>> sum5(1, 2, 3)(4, 5)
7073
15
7174
"""
75+
76+
def _args_len(func):
77+
if _has_type_hint_support:
78+
from inspect import signature
79+
args = signature(func).parameters
80+
else:
81+
from inspect import getargspec
82+
args = getargspec(func).args
83+
84+
return len(args)
85+
7286
@wraps(func)
7387
def _curried(*args, **kwargs):
7488
f = func
@@ -78,9 +92,7 @@ def _curried(*args, **kwargs):
7892
count += len(f.args)
7993
f = f.func
8094

81-
spec = getargspec(f)
82-
83-
if count == len(spec.args) - len(args):
95+
if count == _args_len(f) - len(args):
8496
return func(*args, **kwargs)
8597

8698
para_func = partial(func, *args, **kwargs)

tests/test_curry.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
import unittest
22

3-
from fn.func import curried
3+
from fn.func import curried, _has_type_hint_support
44

55

66
class Curriedtest(unittest.TestCase):
77

8+
def _assert_instance(self, expected, acutal):
9+
self.assertEqual(expected.__module__, acutal.__module__)
10+
self.assertEqual(expected.__name__, acutal.__name__)
11+
12+
if _has_type_hint_support:
13+
self.assertEqual(expected.__annotations__, acutal.__annotations__)
14+
815
def test_curried_wrapper(self):
916

1017
@curried
@@ -15,16 +22,38 @@ def _child(a, b, c, d):
1522
def _moma(a, b):
1623
return _child(a, b)
1724

18-
def _assert_instance(expected, acutal):
19-
self.assertEqual(expected.__module__, acutal.__module__)
20-
self.assertEqual(expected.__name__, acutal.__name__)
21-
2225
res1 = _moma(1)
23-
_assert_instance(_moma, res1)
26+
self._assert_instance(_moma, res1)
27+
res2 = res1(2)
28+
self._assert_instance(_child, res2)
29+
res3 = res2(3)
30+
self._assert_instance(_child, res3)
31+
res4 = res3(4)
32+
33+
self.assertEqual(res4, 10)
34+
35+
@unittest.skipIf(not _has_type_hint_support, "Type hint aren't supported")
36+
def test_curried_with_annotations_when_they_are_supported(self):
37+
38+
def _custom_sum(a, b, c, d):
39+
return a + b + c + d
40+
41+
_custom_sum.__annotations__ = {
42+
'a': int,
43+
'b': int,
44+
'c': int,
45+
'd': int,
46+
'return': int
47+
}
48+
49+
custom_sum = curried(_custom_sum)
50+
51+
res1 = custom_sum(1)
52+
self._assert_instance(custom_sum, res1)
2453
res2 = res1(2)
25-
_assert_instance(_child, res2)
54+
self._assert_instance(custom_sum, res2)
2655
res3 = res2(3)
27-
_assert_instance(_child, res3)
56+
self._assert_instance(custom_sum, res3)
2857
res4 = res3(4)
2958

3059
self.assertEqual(res4, 10)

0 commit comments

Comments
 (0)