Skip to content
This repository was archived by the owner on Mar 23, 2023. It is now read-only.

Import wildcard #263

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
12 changes: 7 additions & 5 deletions compiler/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,13 +378,15 @@ def visit_Import(self, node):
self.block.bind_var(self.writer, asname, mod.expr)

def visit_ImportFrom(self, node):
# Wildcard imports are not yet supported.
self._write_py_context(node.lineno)
for alias in node.names:
if alias.name == '*':
msg = 'wildcard member import is not implemented: from %s import %s' % (
node.module, alias.name)
raise util.ParseError(node, msg)
self._write_py_context(node.lineno)
module_name = node.module

with self._import(module_name, module_name.count('.')) as module:
self.writer.write_checked_call1(
'πg.LoadMembers(πF, {})', module.expr)
return
if node.module.startswith(_NATIVE_MODULE_PREFIX):
values = [alias.name for alias in node.names]
with self._import_native(node.module, values) as mod:
Expand Down
14 changes: 6 additions & 8 deletions compiler/stmt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,14 +366,12 @@ def testImportFromFutureParseError(self):
self.assertRaisesRegexp(util.ParseError, want_regexp,
stmt.import_from_future, node)

def testImportWildcardMemberRaises(self):
regexp = r'wildcard member import is not implemented: from foo import *'
self.assertRaisesRegexp(util.ParseError, regexp, _ParseAndVisit,
'from foo import *')
regexp = (r'wildcard member import is not '
r'implemented: from __go__.foo import *')
self.assertRaisesRegexp(util.ParseError, regexp, _ParseAndVisit,
'from __go__.foo import *')
def testImportWildcard(self):
result = _GrumpRun(textwrap.dedent("""\
from time import *
print sleep"""))
self.assertEqual(0, result[0])
self.assertIn('<function sleep at', result[1])

def testVisitFuture(self):
testcases = [
Expand Down
52 changes: 52 additions & 0 deletions runtime/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,58 @@ func ImportNativeModule(f *Frame, name string, members map[string]*Object) (*Obj
return prev, nil
}

// LoadMembers scans over all the members in module
// and populates globals with them, taking __all__ into
// account.
func LoadMembers(f *Frame, module *Object) *BaseException {
allAttr, raised := GetAttr(f, module, NewStr("__all__"), nil)
if raised != nil && !raised.isInstance(AttributeErrorType) {
return raised
}
f.RestoreExc(nil, nil)

if raised == nil {
raised = loadMembersFromIterable(f, module, allAttr, nil)
if raised != nil {
return raised
}
return nil
}

// Fall back on __dict__
dictAttr := module.dict.ToObject()
raised = loadMembersFromIterable(f, module, dictAttr, func(key *Object) bool {
return strings.HasPrefix(toStrUnsafe(key).value, "_")
})
if raised != nil {
return raised
}
return nil
}

func loadMembersFromIterable(f *Frame, module, iterable *Object, filterF func(*Object) bool) *BaseException {
globals := f.Globals()
raised := seqForEach(f, iterable, func(memberName *Object) *BaseException {
if !memberName.isInstance(StrType) {
errorMessage := fmt.Sprintf("attribute name must be string, not '%v'", memberName.typ.Name())
return f.RaiseType(AttributeErrorType, errorMessage)
}
member, raised := GetAttr(f, module, toStrUnsafe(memberName), nil)
if raised != nil {
return raised
}
if filterF != nil && filterF(memberName) {
return nil
}
raised = globals.SetItem(f, memberName, member)
if raised != nil {
return raised
}
return nil
})
return raised
}

// newModule creates a new Module object with the given fully qualified name
// (e.g a.b.c) and its corresponding Python filename.
func newModule(name, filename string) *Module {
Expand Down
63 changes: 63 additions & 0 deletions runtime/module_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,69 @@ func TestImportNativeModule(t *testing.T) {
}
}

func TestLoadMembers(t *testing.T) {
var1 := NewStr("var1")
var2 := NewStr("_var2")
var3 := NewStr("var3")
nameAttr := NewStr("__name__")
allAttr := NewStr("__all__")
val1 := NewStr("val1")
val2 := NewStr("val2")
val3 := NewStr("val3")
nameValue := NewStr("foo")
allValue := newTestList(var1, var2)
invalidMembers := newTestList(NewInt(1))
invalidMemberName := NewInt(1)
nonIterableValue := NewInt(1)

allDefinedDict := newTestDict(var1, val1, var2, val2, var3, val3, nameAttr, nameValue, allAttr, allValue)
allDefinedModule := &Module{Object: Object{typ: testModuleType, dict: allDefinedDict}}
allUndefinedDict := newTestDict(var1, val1, var2, val2, var3, val3, nameAttr, nameValue)
allUndefinedModule := &Module{Object: Object{typ: testModuleType, dict: allUndefinedDict}}
allInvalidDict := newTestDict(allAttr, invalidMembers)
allInvalidModule := &Module{Object: Object{typ: testModuleType, dict: allInvalidDict}}
packageNamesInvalidDict := newTestDict(invalidMemberName, val1)
packageNamesInvalidModule := &Module{Object: Object{typ: testModuleType, dict: packageNamesInvalidDict}}
allNonIterableDict := newTestDict(allAttr, nonIterableValue)
allNonIterableModule := &Module{Object: Object{typ: testModuleType, dict: allNonIterableDict}}

fun := wrapFuncForTest(func(f *Frame, module *Module) (*Dict, *BaseException) {
f.globals = NewDict()
raised := LoadMembers(f, module.ToObject())
if raised != nil {
return nil, raised
}
return f.Globals(), nil
})
cases := []invokeTestCase{
{
args: wrapArgs(allDefinedModule),
want: newTestDict(var1, val1, var2, val2).ToObject(),
},
{
args: wrapArgs(allUndefinedModule),
want: newTestDict(var1, val1, var3, val3).ToObject(),
},
{
args: wrapArgs(allInvalidModule),
wantExc: mustCreateException(AttributeErrorType, "attribute name must be string, not 'int'"),
},
{
args: wrapArgs(packageNamesInvalidModule),
wantExc: mustCreateException(AttributeErrorType, "attribute name must be string, not 'int'"),
},
{
args: wrapArgs(allNonIterableModule),
wantExc: mustCreateException(TypeErrorType, "'int' object is not iterable"),
},
}
for _, cas := range cases {
if err := runInvokeTestCase(fun, &cas); err != "" {
t.Error(err)
}
}
}

func TestModuleGetNameAndFilename(t *testing.T) {
fun := wrapFuncForTest(func(f *Frame, m *Module) (*Tuple, *BaseException) {
name, raised := m.GetName(f)
Expand Down