Skip to content

Commit 48ee9ea

Browse files
committed
Persistent ordering for top level tree nodes
1 parent b72e4cd commit 48ee9ea

File tree

2 files changed

+40
-21
lines changed

2 files changed

+40
-21
lines changed

django_pgtree/models.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from django.contrib.postgres.indexes import GistIndex
22
from django.db import models
3+
from django.db.models import functions as f
34
from django.db.transaction import atomic
45
from django.utils.crypto import get_random_string
56

@@ -14,6 +15,9 @@ class LtreeConcat(models.Func):
1415
class Subpath(models.Func):
1516
function = 'subpath'
1617

18+
class Text2Ltree(models.Func):
19+
function = 'text2ltree'
20+
1721

1822
class TreeNode(models.Model):
1923
__old_tree_path = None
@@ -22,6 +26,7 @@ class TreeNode(models.Model):
2226
class Meta:
2327
abstract = True
2428
indexes = (GistIndex(fields=['tree_path'], name='tree_path_idx'),)
29+
ordering = ('tree_path',)
2530

2631
def __init__(self, *args, parent=None, **kwargs):
2732
if parent is not None:
@@ -45,31 +50,40 @@ def parent(self, new_parent):
4550
self.tree_path = [*new_parent.tree_path, get_random_string(length=32)]
4651

4752
def save(self, *args, **kwargs): # pylint: disable=arguments-differ
53+
tree_path_needs_refresh = False
4854
if not self.tree_path:
55+
tree_path_needs_refresh = True
4956
# Ensure that we have a tree_path set. We set a random one at this point,
5057
# because we don't know whether this node will become the parent of other
5158
# nodes down the track.
52-
self.tree_path = [get_random_string(length=32)]
59+
largest_ltree = models.Subquery(self.__class__.objects.order_by('-tree_path').values('tree_path')[:1])
60+
largest_ltree_root_num = f.Cast(f.Cast(Subpath(largest_ltree, 0, 1), models.CharField()), models.BigIntegerField())
61+
self.tree_path = Text2Ltree(f.Cast(f.Coalesce(largest_ltree_root_num, -2**32) + (2**32), models.CharField()))
5362

5463
# If we haven't changed the parent, save as normal.
5564
if self.__old_tree_path is None:
56-
return super().save(*args, **kwargs)
65+
rv = super().save(*args, **kwargs)
5766

5867
# If we have, use a transaction to avoid other contexts seeing the intermediate
5968
# state where our descendants aren't connected to us.
60-
with atomic():
61-
rv = super().save(*args, **kwargs)
62-
# Move all of our descendants along with us, by substituting our old ltree
63-
# prefix with our new one, in every descendant that has that prefix.
64-
self.__class__.objects.filter(
65-
tree_path__descendant_of=self.__old_tree_path
66-
).update(
67-
tree_path=LtreeConcat(
68-
models.Value('.'.join(self.tree_path)),
69-
Subpath(models.F('tree_path'), len(self.__old_tree_path)),
69+
else:
70+
with atomic():
71+
rv = super().save(*args, **kwargs)
72+
# Move all of our descendants along with us, by substituting our old ltree
73+
# prefix with our new one, in every descendant that has that prefix.
74+
self.__class__.objects.filter(
75+
tree_path__descendant_of=self.__old_tree_path
76+
).update(
77+
tree_path=LtreeConcat(
78+
models.Value('.'.join(self.tree_path)),
79+
Subpath(models.F('tree_path'), len(self.__old_tree_path)),
80+
)
7081
)
71-
)
72-
return rv
82+
83+
if tree_path_needs_refresh:
84+
self.refresh_from_db(fields=('tree_path',))
85+
86+
return rv
7387

7488
@property
7589
def ancestors(self):

django_pgtree/tests.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ def animal():
1919

2020

2121
def test_descendants(animal):
22-
assert {x.name for x in animal.descendants} == {
23-
'Mammal', 'Marsupial', 'Cat', 'Dog', 'Koala', 'Kangaroo'}
22+
assert [x.name for x in animal.descendants] == [
23+
'Mammal', 'Cat', 'Dog', 'Marsupial', 'Koala', 'Kangaroo']
2424

2525

2626
def test_ancestors(animal):
2727
koala = T.objects.get(name='Koala')
28-
assert {x.name for x in koala.ancestors} == {'Marsupial', 'Animal'}
28+
assert [x.name for x in koala.ancestors] == ['Animal', 'Marsupial']
2929

3030

3131
def test_parent(animal):
@@ -34,13 +34,12 @@ def test_parent(animal):
3434

3535

3636
def test_children(animal):
37-
assert {x.name for x in animal.children} == {'Mammal', 'Marsupial'}
37+
assert [x.name for x in animal.children] == ['Mammal', 'Marsupial']
3838

3939

4040
def test_family(animal):
4141
mammal = T.objects.get(name='Mammal')
42-
assert {x.name for x in mammal.family} == {
43-
'Animal', 'Mammal', 'Cat', 'Dog'}
42+
assert [x.name for x in mammal.family] == ['Animal', 'Mammal', 'Cat', 'Dog']
4443

4544
def test_reparent(animal):
4645
marsupial = T.objects.get(name='Marsupial')
@@ -51,4 +50,10 @@ def test_reparent(animal):
5150
koala = T.objects.get(name='Koala')
5251
assert koala.parent == marsupial
5352
assert koala.tree_path[:2] == mammal.tree_path
54-
assert mammal in koala.ancestors
53+
assert mammal in koala.ancestors
54+
55+
56+
def test_top_level_ordering(animal):
57+
all_l = list(T.objects.all())
58+
assert all_l[0] == animal
59+
assert all_l[-1].name == 'Plant'

0 commit comments

Comments
 (0)