Skip to content

Commit a6506bd

Browse files
committed
Type sig for @generate
1 parent 4d4f0ac commit a6506bd

File tree

2 files changed

+25
-20
lines changed

2 files changed

+25
-20
lines changed

src/parsy/__init__.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import re
1010
from dataclasses import dataclass
1111
from functools import reduce, wraps
12-
from typing import Any, Callable, FrozenSet, Generic, Optional, TypeVar, Union
12+
from typing import Any, Callable, FrozenSet, Generator, Generic, Optional, TypeVar, Union
1313

1414

1515
from .version import __version__ # noqa: F401
@@ -359,11 +359,25 @@ def __lshift__(self: Parser[OUT1], other: Parser) -> Parser[OUT1]:
359359
return self.skip(other)
360360

361361

362-
# combinator syntax
363-
def generate(fn):
362+
# TODO:
363+
# I think @generate is unfixable. It's not surprising, because
364+
# we are doing something genuninely unusual with generator functions.
365+
366+
# The return value of a `@generate` parser is now OK.
367+
368+
# But we have no type checking within a user's @generate function.
369+
370+
# The big issue is that each `val = yield parser` inside a @generate parser has
371+
# a different type, and we'd like those to be typed checked. But the
372+
# `Generator[...]` expects a homogeneous stream of yield and send types,
373+
# whereas we have pairs of yield/send types which need to match within the
374+
# pair, but each pair can be completely different from the next in the stream
375+
376+
377+
def generate(fn: Callable[[], Generator[Parser[Any], Any, OUT]]) -> Parser[OUT]:
364378
@Parser
365379
@wraps(fn)
366-
def generated(stream, index):
380+
def generated(stream: str, index: int) -> Result[OUT]:
367381
# start up the generator
368382
iterator = fn()
369383

@@ -379,9 +393,6 @@ def generated(stream, index):
379393
index = result.index
380394
except StopIteration as stop:
381395
returnVal = stop.value
382-
if isinstance(returnVal, Parser):
383-
return returnVal(stream, index).aggregate(result)
384-
385396
return Result.success(index, returnVal).aggregate(result)
386397

387398
return generated

tests/test_parsy.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
except ImportError:
55
enum = None
66
import re
7+
from typing import Generator
78
import unittest
8-
from collections import namedtuple
9-
from datetime import date
9+
10+
from typing import Any
1011

1112
from parsy import (
13+
Parser,
1214
ParseError,
1315
any_char,
1416
char_from,
@@ -134,7 +136,7 @@ def test_generate(self):
134136
x = y = None
135137

136138
@generate
137-
def xy():
139+
def xy() -> Generator[Parser[Any], Any, int]:
138140
nonlocal x
139141
nonlocal y
140142
x = yield string("x")
@@ -145,14 +147,6 @@ def xy():
145147
self.assertEqual(x, "x")
146148
self.assertEqual(y, "y")
147149

148-
def test_generate_return_parser(self):
149-
@generate
150-
def example():
151-
yield string("x")
152-
return string("y")
153-
154-
self.assertEqual(example.parse("xy"), "y")
155-
156150
def test_mark(self):
157151
parser = (letter.many().mark() << string("\n")).many()
158152

@@ -189,7 +183,7 @@ def test_multiple_failures(self):
189183

190184
def test_generate_backtracking(self):
191185
@generate
192-
def xy():
186+
def xy() -> Generator[Parser[Any], Any, None]:
193187
yield string("x")
194188
yield string("y")
195189
assert False
@@ -480,7 +474,7 @@ def test_decimal_digit(self):
480474

481475
def test_line_info(self):
482476
@generate
483-
def foo():
477+
def foo() -> Generator[Any, Any, tuple[str, tuple[int, int]]]:
484478
i = yield line_info
485479
l = yield any_char
486480
return (l, i)

0 commit comments

Comments
 (0)