Skip to content

Commit 3b8836e

Browse files
authored
Merge pull request #4 from jimustafa/develop
separate raw parsers and add schema validation
2 parents 84a33e4 + 54ccc9d commit 3b8836e

File tree

8 files changed

+104
-28
lines changed

8 files changed

+104
-28
lines changed

requirements/dev-requirements.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ mkdocs-material
22
mkdocstrings[python]
33
pip-tools
44
pre-commit
5+
pydantic
56
pytest

requirements/dev-requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ pre-commit==2.17.0
9696
# via -r dev-requirements.in
9797
py==1.11.0
9898
# via pytest
99+
pydantic==1.9.1
100+
# via -r dev-requirements.in
99101
pygments==2.12.0
100102
# via mkdocs-material
101103
pymdown-extensions==9.4
@@ -131,6 +133,7 @@ tomli==2.0.1
131133
typing-extensions==4.1.1
132134
# via
133135
# importlib-metadata
136+
# pydantic
134137
# pytkdocs
135138
virtualenv==20.13.1
136139
# via pre-commit

src/w90io/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def parse_win(args):
2323
'blocks': blocks,
2424
})
2525
else:
26-
parsed_win = w90io.parse_win(contents)
26+
parsed_win = w90io.parse_win_raw(contents)
2727
if args.parameters:
2828
pp.pprint({
2929
parameter: parsed_win['parameters'][parameter]
@@ -71,7 +71,7 @@ def parse_nnkp(args):
7171
'blocks': blocks,
7272
})
7373
else:
74-
parsed_nnkp = w90io.parse_nnkp(contents)
74+
parsed_nnkp = w90io.parse_nnkp_raw(contents)
7575
if args.parameters:
7676
pp.pprint({
7777
parameter: parsed_nnkp['parameters'][parameter]

src/w90io/_nnkp.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from . import _core
77

88

9-
__all__ = ['parse_nnkp']
9+
__all__ = ['parse_nnkp_raw']
1010

1111

1212
patterns = {
@@ -38,13 +38,12 @@ def parse_lattice(string: str) -> dict:
3838
match = patterns['lattice_vectors'].search(string)
3939

4040
if match is not None:
41-
v1 = np.fromstring(match.group('v1'), sep=' ')
42-
v2 = np.fromstring(match.group('v2'), sep=' ')
43-
v3 = np.fromstring(match.group('v3'), sep=' ')
41+
v1 = [float(x) for x in match.group('v1').split()]
42+
v2 = [float(x) for x in match.group('v2').split()]
43+
v3 = [float(x) for x in match.group('v3').split()]
4444

4545
return {
4646
'v1': v1, 'v2': v2, 'v3': v3,
47-
'lattice_vectors': np.array([v1, v2, v3]),
4847
}
4948
else:
5049
return None
@@ -55,7 +54,6 @@ def parse_direct_lattice(string: str) -> dict:
5554

5655
return {
5756
'a1': lattice['v1'], 'a2': lattice['v2'], 'a3': lattice['v3'],
58-
'lattice_vectors': lattice['lattice_vectors']
5957
}
6058

6159

@@ -64,15 +62,16 @@ def parse_reciprocal_lattice(string: str) -> dict:
6462

6563
return {
6664
'b1': lattice['v1'], 'b2': lattice['v2'], 'b3': lattice['v3'],
67-
'lattice_vectors': lattice['lattice_vectors']
6865
}
6966

7067

7168
def parse_kpoints(string: str) -> dict:
7269
match = patterns['kpoints'].search(string)
7370

7471
return {
75-
'kpoints': np.fromstring(match.group('kpoints'), sep='\n').reshape((len(match.group('kpoints').splitlines()), -1))[:, :3]
72+
'kpoints': [
73+
[float(x) for x in line.split()] for line in match.group('kpoints').splitlines()
74+
]
7675
}
7776

7877

@@ -119,19 +118,17 @@ def parse_spinor_projections(string: str) -> dict:
119118
def parse_exclude_bands(string: str) -> dict:
120119
match = patterns['exclude_bands'].search(string)
121120

122-
exclude_bands = np.fromstring(match.group('exclude_bands'), sep='\n', dtype=int)
123-
124-
if exclude_bands.size > 0:
121+
if match is not None:
125122
return {
126-
'exclude_bands': exclude_bands
123+
'exclude_bands': [int(line) for line in match.group('exclude_bands').splitlines()]
127124
}
128125
else:
129126
return {
130127
'exclude_bands': None
131128
}
132129

133130

134-
def parse_nnkp(string: str) -> dict:
131+
def parse_nnkp_raw(string: str) -> dict:
135132
"""
136133
Parse NNKP
137134

src/w90io/_schema.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from __future__ import annotations
2+
import typing
3+
4+
import pydantic
5+
6+
7+
class UnitCell(pydantic.BaseModel):
8+
units: typing.Optional[str]
9+
a1: list
10+
a2: list
11+
a3: list
12+
13+
14+
class DirectLattice(pydantic.BaseModel):
15+
a1: list
16+
a2: list
17+
a3: list
18+
19+
20+
class ReciprocalLattice(pydantic.BaseModel):
21+
b1: list
22+
b2: list
23+
b3: list
24+
25+
26+
class Atom(pydantic.BaseModel):
27+
species: str
28+
basis_vector: list
29+
30+
31+
class Atoms(pydantic.BaseModel):
32+
units: typing.Optional[str]
33+
atoms: typing.List[Atoms]
34+
35+
36+
class Projections(pydantic.BaseModel):
37+
units: typing.Optional[str]
38+
projections: typing.List[str]
39+
40+
41+
class Kpoints(pydantic.BaseModel):
42+
kpoints: typing.List[list]
43+
44+
45+
class ExcludeBands(pydantic.BaseModel):
46+
exclude_bands: typing.Optional[typing.List[int]]
47+
48+
49+
class WIN(pydantic.BaseModel):
50+
comments: typing.List[str]
51+
parameters: dict
52+
blocks: dict
53+
unit_cell_cart: typing.Optional[UnitCell]
54+
# atoms_frac: Atoms
55+
projections: Projections
56+
kpoints: Kpoints
57+
58+
59+
class NNKP(pydantic.BaseModel):
60+
comments: typing.List[str]
61+
parameters: dict
62+
blocks: dict
63+
direct_lattice: DirectLattice
64+
reciprocal_lattice: ReciprocalLattice
65+
kpoints: Kpoints
66+
exclude_bands: ExcludeBands

src/w90io/_win.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@
33

44
from . import _core
55

6-
import numpy as np
76

8-
9-
__all__ = ['parse_win']
7+
__all__ = ['parse_win_raw']
108

119

1210
patterns = {
@@ -34,14 +32,13 @@ def parse_unit_cell(string: str) -> dict:
3432
match = patterns['unit_cell'].search(string)
3533

3634
if match is not None:
37-
a1 = np.fromstring(match.group('a1'), sep=' ')
38-
a2 = np.fromstring(match.group('a2'), sep=' ')
39-
a3 = np.fromstring(match.group('a3'), sep=' ')
35+
a1 = [float(x) for x in match.group('a1').split()]
36+
a2 = [float(x) for x in match.group('a2').split()]
37+
a3 = [float(x) for x in match.group('a3').split()]
4038

4139
return {
4240
'units': match.group('units'),
4341
'a1': a1, 'a2': a2, 'a3': a3,
44-
'lattice_vectors': np.array([a1, a2, a3]),
4542
}
4643
else:
4744
return None
@@ -56,7 +53,7 @@ def parse_atoms(string: str) -> dict:
5653
'atoms': [
5754
{
5855
'species': line.split()[0],
59-
'basis_vector': np.fromiter(map(float, line.split()[1:]), dtype=float),
56+
'basis_vector': [float(x) for x in line.split()[1:]],
6057
}
6158
for line in match.group('atoms').splitlines()
6259
]
@@ -79,11 +76,13 @@ def parse_projections(string: str) -> dict:
7976

8077
def parse_kpoints(string: str) -> dict:
8178
return {
82-
'kpoints': np.fromstring(string, sep='\n').reshape((len(string.splitlines()), -1))[:, :3]
79+
'kpoints': [
80+
[float(x) for x in line.split()] for line in string.splitlines()
81+
]
8382
}
8483

8584

86-
def parse_win(string: str) -> dict:
85+
def parse_win_raw(string: str) -> dict:
8786
"""
8887
Parse WIN
8988

tests/test_nnkp.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
import w90io
5+
import w90io._schema
66

77

88
@pytest.mark.parametrize('example', [f'example{i:02d}' for i in [1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 13, 17, 18, 19, 20]])
@@ -11,6 +11,11 @@ def test_parse_nnkp(wannier90, example):
1111
contents = fh.read()
1212

1313
try:
14-
w90io.parse_nnkp(contents)
14+
parsed_nnkp = w90io.parse_nnkp_raw(contents)
15+
except Exception:
16+
assert False
17+
18+
try:
19+
w90io._schema.NNKP.parse_obj(parsed_nnkp)
1520
except Exception:
1621
assert False

tests/test_win.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ def test_parse_win(wannier90, example):
1111
contents = fh.read()
1212

1313
try:
14-
w90io.parse_win(contents)
14+
parsed_win = w90io.parse_win_raw(contents)
15+
except Exception:
16+
assert False
17+
18+
try:
19+
w90io._schema.WIN.parse_obj(parsed_win)
1520
except Exception:
1621
assert False

0 commit comments

Comments
 (0)