Skip to content

Commit 402176e

Browse files
committed
Working reordering
1 parent 48ee9ea commit 402176e

File tree

5 files changed

+95
-17
lines changed

5 files changed

+95
-17
lines changed

django_pgtree/fields.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,13 @@ def to_python(self, value):
2424
raise ValueError("Don't know how to handle {!r}".format(value))
2525

2626
def get_prep_value(self, value):
27-
if isinstance(value, str):
27+
if isinstance(value, str) or value is None:
2828
return value
2929
return '.'.join(value)
3030

3131
def from_db_value(self, value, expression, connection):
32+
if not value:
33+
return []
3234
return value.split('.')
3335

3436

django_pgtree/models.py

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from .fields import LtreeField
88

9+
GAP = 1_000_000_000
10+
911

1012
class LtreeConcat(models.Func):
1113
arg_joiner = '||'
@@ -20,8 +22,8 @@ class Text2Ltree(models.Func):
2022

2123

2224
class TreeNode(models.Model):
23-
__old_tree_path = None
24-
tree_path = LtreeField()
25+
__new_parent = None
26+
tree_path = LtreeField(unique=True)
2527

2628
class Meta:
2729
abstract = True
@@ -30,14 +32,15 @@ class Meta:
3032

3133
def __init__(self, *args, parent=None, **kwargs):
3234
if parent is not None:
33-
kwargs['tree_path'] = [*parent.tree_path,
34-
get_random_string(length=32)]
35+
self.__new_parent = parent
3536
super().__init__(*args, **kwargs)
3637

3738
@property
3839
def parent(self):
40+
if self.__new_parent is not None:
41+
return self.__new_parent
3942
parent_path = self.tree_path[:-
40-
1] # pylint: disable=unsubscriptable-object
43+
1] # pylint: disable=unsubscriptable-object
4144
return self.__class__.objects.get(tree_path=parent_path)
4245

4346
@parent.setter
@@ -46,22 +49,54 @@ def parent(self, new_parent):
4649
raise ValueError(
4750
"Parent node must be saved before receiving children")
4851
# Replace our tree_path with a new one that has our new parent's
49-
self.__old_tree_path = self.tree_path
50-
self.tree_path = [*new_parent.tree_path, get_random_string(length=32)]
52+
self.__new_parent = new_parent
53+
54+
def __next_tree_path_qx(self, prefix=None):
55+
if prefix is None:
56+
prefix = []
57+
58+
# These are all the siblings of the target position, in reverse tree order.
59+
# If we don't have a prefix, this will be all root nodes.
60+
sibling_queryset = self.__class__.objects.filter(tree_path__matches_lquery=[*prefix, '*{1}']).order_by('-tree_path')
61+
# This query expression is the full ltree of the last sibling by tree order.
62+
last_sibling_tree_path = models.Subquery(sibling_queryset.values('tree_path')[:1])
63+
64+
# Django doesn't allow the use of column references in an INSERT statement,
65+
# because it makes the assumption that they refer to columns in the
66+
# to-be-inserted row, the values for which aren't yet known.
67+
# Unfortunately, this means we can't use a subquery that refers to column
68+
# values anywhere internally, even though the columns it refers to are subquery
69+
# result columns. To get around this, we override the contains_column_references
70+
# property on the subquery with a static False, so that Django's check doesn't
71+
# cross the subquery boundary.
72+
last_sibling_tree_path.contains_column_references = False
73+
74+
# This query expression is the rightmost component of that ltree. The double
75+
# cast is because PostgreSQL doesn't let you cast directly from ltree to bigint.
76+
last_sibling_last_value = f.Cast(f.Cast(Subpath(last_sibling_tree_path, -1), models.CharField()), models.BigIntegerField())
77+
# This query expression is an ltree containing that value, plus GAP, or just
78+
# GAP if there is no existing siblings. Again, we need to double cast.
79+
new_last_value = Text2Ltree(f.Cast(f.Coalesce(last_sibling_last_value, 0) + (GAP), models.CharField()))
80+
81+
# If we have a prefix, we prepend that to the resulting ltree.
82+
if not prefix:
83+
return new_last_value
84+
return LtreeConcat(models.Value('.'.join(prefix)), new_last_value)
5185

5286
def save(self, *args, **kwargs): # pylint: disable=arguments-differ
5387
tree_path_needs_refresh = False
88+
old_tree_path = None
89+
90+
if self.__new_parent is not None:
91+
tree_path_needs_refresh = True
92+
old_tree_path = self.tree_path or None
93+
self.tree_path = self.__next_tree_path_qx(self.__new_parent.tree_path)
5494
if not self.tree_path:
5595
tree_path_needs_refresh = True
56-
# Ensure that we have a tree_path set. We set a random one at this point,
57-
# because we don't know whether this node will become the parent of other
58-
# nodes down the track.
59-
largest_ltree = models.Subquery(self.__class__.objects.order_by('-tree_path').values('tree_path')[:1])
60-
largest_ltree_root_num = f.Cast(f.Cast(Subpath(largest_ltree, 0, 1), models.CharField()), models.BigIntegerField())
61-
self.tree_path = Text2Ltree(f.Cast(f.Coalesce(largest_ltree_root_num, -2**32) + (2**32), models.CharField()))
96+
self.tree_path = self.__next_tree_path_qx()
6297

6398
# If we haven't changed the parent, save as normal.
64-
if self.__old_tree_path is None:
99+
if old_tree_path is None:
65100
rv = super().save(*args, **kwargs)
66101

67102
# If we have, use a transaction to avoid other contexts seeing the intermediate
@@ -71,18 +106,21 @@ def save(self, *args, **kwargs): # pylint: disable=arguments-differ
71106
rv = super().save(*args, **kwargs)
72107
# Move all of our descendants along with us, by substituting our old ltree
73108
# prefix with our new one, in every descendant that has that prefix.
109+
self.refresh_from_db(fields=('tree_path',))
110+
tree_path_needs_refresh = False
74111
self.__class__.objects.filter(
75-
tree_path__descendant_of=self.__old_tree_path
112+
tree_path__descendant_of=old_tree_path
76113
).update(
77114
tree_path=LtreeConcat(
78115
models.Value('.'.join(self.tree_path)),
79-
Subpath(models.F('tree_path'), len(self.__old_tree_path)),
116+
Subpath(models.F('tree_path'), len(old_tree_path)),
80117
)
81118
)
82119

83120
if tree_path_needs_refresh:
84121
self.refresh_from_db(fields=('tree_path',))
85122

123+
print('for object {!r}, old_tree_path is {!r}, tree_path is {!r}'.format(self, old_tree_path, self.tree_path))
86124
return rv
87125

88126
@property

django_pgtree/tests.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,26 @@ def test_family(animal):
4444
def test_reparent(animal):
4545
marsupial = T.objects.get(name='Marsupial')
4646
mammal = T.objects.get(name='Mammal')
47+
48+
dog = T.objects.get(name='Dog')
49+
dog_tree_path = dog.tree_path
50+
plant = T.objects.get(name='Plant')
51+
plant_tree_path = plant.tree_path
52+
4753
marsupial.parent = mammal
4854
marsupial.save()
55+
4956
assert marsupial.tree_path[:2] == mammal.tree_path
5057
koala = T.objects.get(name='Koala')
5158
assert koala.parent == marsupial
5259
assert koala.tree_path[:2] == mammal.tree_path
5360
assert mammal in koala.ancestors
5461

62+
dog.refresh_from_db()
63+
assert dog.tree_path == dog_tree_path
64+
plant.refresh_from_db()
65+
assert plant.tree_path == plant_tree_path
66+
5567

5668
def test_top_level_ordering(animal):
5769
all_l = list(T.objects.all())
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Generated by Django 2.1.2 on 2018-10-11 02:29
2+
3+
from django.db import migrations
4+
import django_pgtree.fields
5+
6+
7+
class Migration(migrations.Migration):
8+
9+
dependencies = [
10+
('testapp', '0001_initial'),
11+
]
12+
13+
operations = [
14+
migrations.AlterModelOptions(
15+
name='testmodel',
16+
options={'ordering': ('tree_path',)},
17+
),
18+
migrations.AlterField(
19+
model_name='testmodel',
20+
name='tree_path',
21+
field=django_pgtree.fields.LtreeField(unique=True),
22+
),
23+
]

testproject/testapp/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,6 @@
55

66
class TestModel(TreeNode):
77
name = models.CharField(max_length=128)
8+
9+
def __str__(self):
10+
return self.name

0 commit comments

Comments
 (0)