Skip to content

Commit ca32253

Browse files
authored
Merge pull request #94 from DanCardin/dc/view-metadata-sequence
fix: Handle view metadata sequence.
2 parents 2c329d8 + aec9ee4 commit ca32253

File tree

16 files changed

+233
-42
lines changed

16 files changed

+233
-42
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "sqlalchemy-declarative-extensions"
3-
version = "0.15.2"
3+
version = "0.15.3"
44
authors = ["Dan Cardin <[email protected]>"]
55

66
description = "Library to declare additional kinds of objects not natively supported by SQLAlchemy/Alembic."

src/sqlalchemy_declarative_extensions/alembic/row.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from alembic.autogenerate.api import AutogenContext
44
from alembic.operations import Operations
55
from alembic.operations.ops import UpgradeOps
6+
from sqlalchemy import MetaData
67

78
from sqlalchemy_declarative_extensions import row
89
from sqlalchemy_declarative_extensions.alembic.base import (
@@ -20,18 +21,14 @@
2021

2122

2223
def compare_rows(autogen_context: AutogenContext, upgrade_ops: UpgradeOps, _):
23-
if (
24-
autogen_context.metadata is None or autogen_context.connection is None
25-
): # pragma: no cover
24+
optional_rows: tuple[Rows, MetaData] | None = Rows.extract(autogen_context.metadata)
25+
if not optional_rows:
2626
return
2727

28-
rows: Rows | None = autogen_context.metadata.info.get("rows")
29-
if not rows:
30-
return
28+
rows, metadata = optional_rows
3129

32-
result = row.compare.compare_rows(
33-
autogen_context.connection, autogen_context.metadata, rows
34-
)
30+
assert autogen_context.connection
31+
result = row.compare.compare_rows(autogen_context.connection, metadata, rows)
3532
upgrade_ops.ops.extend(result) # type: ignore
3633

3734

src/sqlalchemy_declarative_extensions/alembic/view.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
from __future__ import annotations
2+
13
from alembic.autogenerate.api import AutogenContext
24

35
from sqlalchemy_declarative_extensions.alembic.base import (
46
register_comparator_dispatcher,
57
register_renderer_dispatcher,
68
register_rewriter_dispatcher,
79
)
10+
from sqlalchemy_declarative_extensions.view.base import Views
811
from sqlalchemy_declarative_extensions.view.compare import (
912
CreateViewOp,
1013
DropViewOp,
@@ -14,13 +17,13 @@
1417
)
1518

1619

17-
def _compare_views(autogen_context, upgrade_ops, _):
18-
metadata = autogen_context.metadata
19-
views = metadata.info.get("views")
20+
def _compare_views(autogen_context: AutogenContext, upgrade_ops, _):
21+
views: Views | None = Views.extract(autogen_context.metadata)
2022
if not views:
2123
return
2224

23-
result = compare_views(autogen_context.connection, views, metadata)
25+
assert autogen_context.connection
26+
result = compare_views(autogen_context.connection, views)
2427
upgrade_ops.ops.extend(result)
2528

2629

src/sqlalchemy_declarative_extensions/dialects/postgresql/view.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from dataclasses import dataclass, field, replace
44
from typing import Any, Literal
55

6-
from sqlalchemy import MetaData
76
from sqlalchemy.engine import Connection, Dialect
87
from typing_extensions import override
98

@@ -121,9 +120,12 @@ def to_sql_create(self, dialect: Dialect) -> list[str]:
121120
return result
122121

123122
def normalize(
124-
self, conn: Connection, metadata: MetaData, using_connection: bool = True
123+
self,
124+
conn: Connection,
125+
naming_convention: base.NamingConvention | None,
126+
using_connection: bool = True,
125127
) -> View:
126-
instance = super().normalize(conn, metadata, using_connection)
128+
instance = super().normalize(conn, naming_convention, using_connection)
127129
return replace(
128130
instance,
129131
materialized=MaterializedOptions.from_value(self.materialized),

src/sqlalchemy_declarative_extensions/dialects/snowflake/view.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from dataclasses import dataclass, replace
44
from typing import Any
55

6-
from sqlalchemy import MetaData
76
from sqlalchemy.engine import Connection, Dialect
87
from typing_extensions import override
98

@@ -52,9 +51,12 @@ def to_sql_create(self, dialect: Dialect) -> list[str]:
5251
return result
5352

5453
def normalize(
55-
self, conn: Connection, metadata: MetaData, using_connection: bool = True
54+
self,
55+
conn: Connection,
56+
naming_convention: base.NamingConvention | None,
57+
using_connection: bool = True,
5658
) -> View:
57-
result = super().normalize(conn, metadata, using_connection)
59+
result = super().normalize(conn, naming_convention, using_connection)
5860
return replace(
5961
result,
6062
schema=self.schema.upper() if self.schema else None,

src/sqlalchemy_declarative_extensions/row/base.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass, field, replace
4-
from typing import Any, Iterable
4+
from typing import Any, Iterable, Sequence
5+
6+
from sqlalchemy import MetaData
7+
from typing_extensions import Self
58

69
from sqlalchemy_declarative_extensions.sql import split_schema
710

@@ -22,6 +25,30 @@ def coerce_from_unknown(cls, unknown: None | Iterable[Row] | Rows) -> Rows | Non
2225

2326
return None
2427

28+
@classmethod
29+
def extract(
30+
cls, metadata: MetaData | list[MetaData | None] | None
31+
) -> tuple[Self, MetaData] | None:
32+
if not isinstance(metadata, Sequence):
33+
metadata = [metadata]
34+
35+
instances: list[Self] = [
36+
m.info["rows"] for m in metadata if m and m.info.get("rows")
37+
]
38+
39+
instance_count = len(instances)
40+
if instance_count == 0:
41+
return None
42+
43+
if instance_count == 1:
44+
metadata = metadata[0]
45+
assert metadata
46+
return instances[0], metadata
47+
48+
raise NotImplementedError(
49+
"Rows is currently only supported on a single instance of MetaData. File an issue if this affects you!"
50+
)
51+
2552
def __iter__(self):
2653
yield from self.rows
2754

src/sqlalchemy_declarative_extensions/view/base.py

Lines changed: 71 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,18 @@
44
import uuid
55
import warnings
66
from dataclasses import dataclass, field, replace
7-
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, TypeVar, cast
7+
from typing import (
8+
TYPE_CHECKING,
9+
Any,
10+
Callable,
11+
Dict,
12+
Iterable,
13+
List,
14+
Optional,
15+
Sequence,
16+
TypeVar,
17+
cast,
18+
)
819

920
from sqlalchemy import Index, MetaData, UniqueConstraint, text
1021
from sqlalchemy.engine import Connection, Dialect
@@ -29,6 +40,7 @@
2940

3041
T = TypeVar("T")
3142
ViewType = TypeVar("ViewType", "View", "DeclarativeView")
43+
NamingConvention = Dict[str, Any]
3244

3345

3446
def view(
@@ -288,12 +300,15 @@ def render_constraints(self, *, create):
288300
return result
289301

290302
def normalize(
291-
self, conn: Connection, metadata: MetaData, using_connection: bool = True
303+
self,
304+
conn: Connection,
305+
naming_convention: NamingConvention | None,
306+
using_connection: bool = True,
292307
) -> Self:
293308
constraints = None
294309
if self.constraints:
295310
constraints = [
296-
ViewIndex.from_unknown(c, self, conn.dialect, metadata)
311+
ViewIndex.from_unknown(c, self, conn.dialect, naming_convention)
297312
for c in self.constraints
298313
]
299314

@@ -365,7 +380,7 @@ def from_unknown(
365380
index: ViewIndex | Index | UniqueConstraint,
366381
source_view: View,
367382
dialect: Dialect,
368-
metadata: MetaData,
383+
naming_convention: NamingConvention | None,
369384
):
370385
if isinstance(index, ViewIndex):
371386
convention = "uq" if index.unique else "ix"
@@ -390,13 +405,13 @@ def from_unknown(
390405
if instance.name:
391406
return instance
392407

393-
naming_convention = metadata.naming_convention or DEFAULT_NAMING_CONVENTION
408+
naming_convention = naming_convention or DEFAULT_NAMING_CONVENTION # type: ignore
409+
assert naming_convention
410+
assert "ix" in naming_convention
394411
template = cast(
395412
str, naming_convention.get(convention) or naming_convention["ix"]
396413
)
397-
cd = ConventionDict(
398-
_ViewIndexAdapter(instance), source_view, metadata.naming_convention
399-
)
414+
cd = ConventionDict(_ViewIndexAdapter(instance), source_view, naming_convention)
400415
conventionalized_name = conv(template % cd)
401416

402417
try:
@@ -504,6 +519,7 @@ class Views:
504519

505520
ignore: Iterable[str] = field(default_factory=set)
506521
ignore_views: Iterable[str] = field(default_factory=set)
522+
naming_convention: NamingConvention | None = None
507523

508524
@classmethod
509525
def coerce_from_unknown(
@@ -517,6 +533,53 @@ def coerce_from_unknown(
517533

518534
return None
519535

536+
@classmethod
537+
def extract(cls, metadata: MetaData | list[MetaData | None] | None) -> Self | None:
538+
if not isinstance(metadata, Sequence):
539+
metadata = [metadata]
540+
541+
naming_conventions = [m.naming_convention for m in metadata if m]
542+
instances: list[Self] = [
543+
m.info["views"] for m in metadata if m and m.info.get("views")
544+
]
545+
546+
instance_count = len(instances)
547+
if instance_count == 0:
548+
return None
549+
550+
if instance_count == 1:
551+
return instances[0]
552+
553+
if not all(
554+
x.ignore_unspecified == instances[0].ignore_unspecified
555+
and x.naming_convention == instances[0].naming_convention
556+
for x in instances
557+
):
558+
raise ValueError(
559+
"All combined `Views` instances must agree on the set of settings: ignore_unspecified, naming_convention"
560+
)
561+
562+
views = [s for instance in instances for s in instance.views]
563+
ignore = [s for instance in instances for s in instance.ignore]
564+
ignore_views = [s for instance in instances for s in instance.ignore_views]
565+
566+
ignore_unspecified = instances[0].ignore_unspecified
567+
naming_convention: NamingConvention = instances[0].naming_convention # type: ignore
568+
569+
if not naming_convention:
570+
if not all(n == naming_conventions[0] for n in naming_conventions):
571+
raise ValueError("All MetaData `naming_convention`s must agree")
572+
573+
naming_convention = naming_conventions[0] # type: ignore
574+
575+
return cls(
576+
views=views,
577+
ignore_unspecified=ignore_unspecified,
578+
ignore=ignore,
579+
ignore_views=ignore_views,
580+
naming_convention=naming_convention,
581+
)
582+
520583
def append(self, view: View | DeclarativeView):
521584
self.views.append(view)
522585

src/sqlalchemy_declarative_extensions/view/compare.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from fnmatch import fnmatch
66
from typing import Union
77

8-
from sqlalchemy import MetaData
98
from sqlalchemy.engine import Connection, Dialect
109

1110
from sqlalchemy_declarative_extensions.dialects import get_view_cls, get_views
@@ -52,7 +51,6 @@ def to_sql(self, dialect: Dialect) -> list[str]:
5251
def compare_views(
5352
connection: Connection,
5453
views: Views,
55-
metadata: MetaData,
5654
normalize_with_connection: bool = True,
5755
) -> list[Operation]:
5856
if views.ignore_views:
@@ -79,7 +77,9 @@ def compare_views(
7977
removed_view_names = existing_view_names - expected_view_names
8078

8179
for view in concrete_defined_views:
82-
normalized_view = view.normalize(connection, metadata, using_connection=False)
80+
normalized_view = view.normalize(
81+
connection, views.naming_convention, using_connection=False
82+
)
8383

8484
view_name = normalized_view.qualified_name
8585

@@ -95,12 +95,16 @@ def compare_views(
9595
result.append(CreateViewOp(normalized_view))
9696
else:
9797
normalized_view = normalized_view.normalize(
98-
connection, metadata, using_connection=normalize_with_connection
98+
connection,
99+
views.naming_convention,
100+
using_connection=normalize_with_connection,
99101
)
100102

101103
existing_view = existing_views_by_name[view_name]
102104
normalized_existing_view = existing_view.normalize(
103-
connection, metadata, using_connection=normalize_with_connection
105+
connection,
106+
views.naming_convention,
107+
using_connection=normalize_with_connection,
104108
)
105109

106110
if normalized_existing_view != normalized_view:

src/sqlalchemy_declarative_extensions/view/ddl.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010

1111
def view_ddl(views: Views, view_filter: list[str] | None = None):
1212
def after_create(metadata: MetaData, connection: Connection, **_):
13-
result = compare_views(
14-
connection, views, metadata, normalize_with_connection=False
15-
)
13+
result = compare_views(connection, views, normalize_with_connection=False)
1614
for op in result:
1715
if not match_name(op.view.qualified_name, view_filter):
1816
continue

tests/examples/test_view_complex_comparison_pg/test_migrations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,5 @@ def test_apply_autogenerated_revision(alembic_runner: MigrationContext, alembic_
2222

2323
# Now a comparison should yield no results, because the view def has not changed.
2424
with alembic_engine.connect() as conn:
25-
result = compare_views(conn, views=Base.views, metadata=Base.metadata)
25+
result = compare_views(conn, views=Base.views)
2626
assert result == []

tests/row/test_metadata_sequence.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import pytest
2+
import sqlalchemy
3+
4+
from sqlalchemy_declarative_extensions import (
5+
Row,
6+
Rows,
7+
declare_database,
8+
)
9+
10+
metadata1 = sqlalchemy.MetaData()
11+
metadata2 = sqlalchemy.MetaData()
12+
metadata3 = sqlalchemy.MetaData()
13+
14+
declare_database(metadata1, rows=Rows().are(Row("foo")))
15+
declare_database(metadata2, rows=Rows().are(Row("bar")))
16+
17+
18+
def test_invalid_combination():
19+
with pytest.raises(NotImplementedError):
20+
Rows.extract([metadata1, metadata1])
21+
22+
23+
def test_single():
24+
rows = Rows.extract(metadata1)
25+
assert rows
26+
assert rows[0] is metadata1.info["rows"]
27+
assert rows[1] is metadata1
28+
29+
rows = Rows.extract([metadata1, metadata3])
30+
assert rows
31+
assert rows[0] is metadata1.info["rows"]
32+
assert rows[1] is metadata1

0 commit comments

Comments
 (0)