Skip to content

Reverse foreign key #28

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 27, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ CONTRIBUTORS are and/or have been (alphabetic order):
Github: <https://github.com/bmihelac>
* Pahaz, Blinov
Github: <https://github.com/pahaz>

* Spencer, Samuel
Github: <https://github.com/LegoStormtroopr>
157 changes: 112 additions & 45 deletions reversion_compare/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
admin
~~~~~

Admin extensions for django-reversion-compare

:copyleft: 2012-2015 by the django-reversion-compare team, see AUTHORS for more details.
Expand All @@ -19,6 +19,7 @@
from django.conf.urls import patterns, url
from django.contrib.admin.util import unquote, quote
from django.core.urlresolvers import reverse
from django.db import models
from django.http import Http404
from django.contrib import admin
from django.shortcuts import get_object_or_404, render_to_response
Expand Down Expand Up @@ -82,12 +83,13 @@ def __cmp__(self, other):
raise NotImplemented()

def __eq__(self, other):
assert self.field.get_internal_type() != "ManyToManyField"

if hasattr(self.field,'get_internal_type'):
assert self.field.get_internal_type() != "ManyToManyField"

if self.value != other.value:
return False

if self.field.get_internal_type() == "ForeignKey": # FIXME!
if not hasattr(self.field,'get_internal_type') or self.field.get_internal_type() == "ForeignKey": # FIXME!
if self.version.field_dict != other.version.field_dict:
return False

Expand All @@ -97,26 +99,41 @@ def __ne__(self, other):
return not self.__eq__(other)

def get_related(self):
if self.field.rel is not None:
if getattr(self.field,'rel',None):
obj = self.version.object_version.object
related = getattr(obj, self.field.name)
return related
return getattr(obj, self.field.name,None)

def get_reverse_foreign_key(self):
obj = self.version.object_version.object
#self = getattr(obj, self.field.related_name) #self.field.field_name
if self.has_int_pk and self.field.related_name and hasattr(obj, self.field.related_name):
ids = [v.id for v in getattr(obj, str(self.field.related_name)).all()] # is: version.field_dict[field.name]
else:
return ([],[],[])

# Get the related model of the current field:
related_model = self.field.field.model
return self.get_many_to_something(ids,related_model)

def get_many_to_many(self):
"""
returns a queryset with all many2many objects
"""
if self.field.get_internal_type() != "ManyToManyField": # FIXME!
return (None, None, None)

if self.field.get_internal_type() != "ManyToManyField": # FIXME!
return ([], [], []) # This prevents an error, as None is not iterable
ids = None
if self.has_int_pk:
ids = [int(v) for v in self.value] # is: version.field_dict[field.name]

# get instance of reversion.models.Revision(): A group of related object versions.
old_revision = self.version.revision

# Get the related model of the current field:
related_model = self.field.rel.to
return self.get_many_to_something(ids,related_model)

def get_many_to_something(self,ids,related_model):

# get instance of reversion.models.Revision():
# A group of related object versions.
old_revision = self.version.revision

# Get a queryset with all related objects.
queryset = old_revision.version_set.filter(
Expand Down Expand Up @@ -211,14 +228,14 @@ def __init__(self, field, field_name, obj, version1, version2, manager):
self.adapter = manager.get_adapter(model) # VersionAdapter instance

# is a related field (ForeignKey, ManyToManyField etc.)
self.is_related = self.field.rel is not None
self.is_related = getattr(self.field,'rel',None) is not None

if not self.is_related:
self.follow = None
elif self.field_name in self.adapter.follow:
self.follow = True
else:
self.follow = False
self.follow = False

self.compare_obj1 = CompareObject(field, field_name, obj, version1, self.has_int_pk, self.adapter)
self.compare_obj2 = CompareObject(field, field_name, obj, version2, self.has_int_pk, self.adapter)
Expand All @@ -228,18 +245,18 @@ def __init__(self, field, field_name, obj, version1, version2, manager):

def changed(self):
""" return True if at least one field has changed values. """
if self.field.get_internal_type() == "ManyToManyField": # FIXME!

if hasattr(self.field,'get_internal_type') and self.field.get_internal_type() == "ManyToManyField":
info = self.get_m2m_change_info()
keys = (
"changed_items", "removed_items", "added_items",
"removed_missing_objects", "added_missing_objects"
)
for key in keys:
if info[key]:
return True
return False
return True
return False

return self.compare_obj1 != self.compare_obj2

def _get_result(self, compare_obj, func_name):
Expand All @@ -263,15 +280,36 @@ def get_many_to_many(self):
m2m_data1, m2m_data2 = self._get_both_results("get_many_to_many")
return m2m_data1, m2m_data2

M2O_CHANGE_INFO = None

def get_reverse_foreign_key(self):
return self._get_both_results("get_reverse_foreign_key")

def get_m2o_change_info(self):
if self.M2O_CHANGE_INFO is not None:
return self.M2O_CHANGE_INFO

m2o_data1, m2o_data2 = self.get_reverse_foreign_key()

self.M2O_CHANGE_INFO = self.get_m2s_change_info(m2o_data1, m2o_data2)
return self.M2O_CHANGE_INFO

M2M_CHANGE_INFO = None
def get_m2m_change_info(self):
if self.M2M_CHANGE_INFO is not None:
return self.M2M_CHANGE_INFO

m2m_data1, m2m_data2 = self.get_many_to_many()

result1, missing_objects1, missing_ids1 = m2m_data1
result2, missing_objects2, missing_ids2 = m2m_data2
self.M2M_CHANGE_INFO = self.get_m2s_change_info(m2m_data1, m2m_data2)
return self.M2M_CHANGE_INFO

# Abstract Many-to-Something (either -many or -one) as both
# many2many and many2one relationships looks the same from the refered object.
def get_m2s_change_info(self,obj1_data,obj2_data):

result1, missing_objects1, missing_ids1 = obj1_data
result2, missing_objects2, missing_ids2 = obj2_data

# missing_objects_pk1 = [obj.pk for obj in missing_objects1]
# missing_objects_pk2 = [obj.pk for obj in missing_objects2]
Expand Down Expand Up @@ -349,7 +387,7 @@ def get_m2m_change_info(self):
else:
raise RuntimeError()

self.M2M_CHANGE_INFO = {
return {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm. Don't know if its needed to set self.M2M_CHANGE_INFO, see above, line 275

"changed_items": changed_items,
"removed_items": removed_items,
"added_items": added_items,
Expand All @@ -358,7 +396,6 @@ def get_m2m_change_info(self):
"removed_missing_objects": removed_missing_objects,
"added_missing_objects": added_missing_objects,
}
return self.M2M_CHANGE_INFO


def debug(self):
Expand Down Expand Up @@ -400,45 +437,45 @@ def debug(self):
class BaseCompareVersionAdmin(VersionAdmin):
"""
Enhanced version of VersionAdmin with a flexible compare version API.

You can define own method to compare fields in two ways (in this order):

Create a method for a field via the field name, e.g.:
"compare_%s" % field_name

Create a method for every field by his internal type
"compare_%s" % field.get_internal_type()

see: https://docs.djangoproject.com/en/1.4/howto/custom-model-fields/#django.db.models.Field.get_internal_type

If no method defined it would build a simple ndiff from repr().

example:

----------------------------------------------------------------------------
class MyModel(models.Model):
date_created = models.DateTimeField(auto_now_add=True)
last_update = models.DateTimeField(auto_now=True)
user = models.ForeignKey(User)
content = models.TextField()
sub_text = models.ForeignKey(FooBar)

class MyModelAdmin(CompareVersionAdmin):
def compare_DateTimeField(self, obj, version1, version2, value1, value2):
''' compare all model datetime model field in ISO format '''
date1 = value1.isoformat(" ")
date2 = value2.isoformat(" ")
html = html_diff(date1, date2)
return html

def compare_sub_text(self, obj, version1, version2, value1, value2):
''' field_name example '''
return "%s -> %s" % (value1, value2)

----------------------------------------------------------------------------
"""

# Template file used for the compare view:
# Template file used for the compare view:
compare_template = "reversion-compare/compare.html"

# list/tuple of field names for the compare view. Set to None for all existing fields
Expand All @@ -447,10 +484,10 @@ def compare_sub_text(self, obj, version1, version2, value1, value2):
# list/tuple of field names to exclude from compare view.
compare_exclude = None

# change template from django-reversion to add compare selection form:
# change template from django-reversion to add compare selection form:
object_history_template = "reversion-compare/object_history.html"

# sort from new to old as default, see: https://github.com/etianen/django-reversion/issues/77
# sort from new to old as default, see: https://github.com/etianen/django-reversion/issues/77
history_latest_first = True

def get_urls(self):
Expand Down Expand Up @@ -538,6 +575,13 @@ def _get_compare_func(suffix):
html = func(obj_compare)
return html

# Determine if its a reverse field
if obj_compare.field in self.reverse_fields:
func = _get_compare_func("ManyToOneRel")
if func is not None:
html = func(obj_compare)
return html

# Try method in the name scheme: "compare_%s" % field.get_internal_type()
internal_type = obj_compare.field.get_internal_type()
func = _get_compare_func(internal_type)
Expand All @@ -552,9 +596,9 @@ def _get_compare_func(suffix):
def compare(self, obj, version1, version2):
"""
Create a generic html diff from the obj between version1 and version2:

A diff of every changes field values.

This method should be overwritten, to create a nice diff view
coordinated with the model.
"""
Expand All @@ -565,12 +609,30 @@ def compare(self, obj, version1, version2):
concrete_model = obj._meta.concrete_model
fields += concrete_model._meta.many_to_many

# This gathers the related reverse ForeignKey fields, so we can do ManyToOne compares
self.reverse_fields = []
# From: http://stackoverflow.com/questions/19512187/django-list-all-reverse-relations-of-a-model
for field_name in obj._meta.get_all_field_names() :
f = getattr(
obj._meta.get_field_by_name(field_name)[0],
'field',
None
)
if isinstance(f, models.ForeignKey) and f not in fields:
self.reverse_fields.append(f.rel)

fields += self.reverse_fields

has_unfollowed_fields = False

for field in fields:
#logger.debug("%s %s %s", field, field.db_type, field.get_internal_type())

field_name = field.name
try:
field_name = field.name
except:
# is a reverse FK field
field_name = field.field_name

if self.compare_fields and field_name not in self.compare_fields:
continue
Expand Down Expand Up @@ -668,7 +730,7 @@ def compare_view(self, request, object_id, extra_context=None):

class CompareVersionAdmin(BaseCompareVersionAdmin):
"""
expand the base class with prepered compare methods.
expand the base class with prepared compare methods.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx! ;)

"""
def generic_add_remove(self, raw_value1, raw_value2, value1, value2):
if raw_value1 is None:
Expand Down Expand Up @@ -698,6 +760,11 @@ def simple_compare_ManyToManyField(self, obj_compare):
html = html_diff(old, new)
return html

def compare_ManyToOneRel(self, obj_compare):
change_info = obj_compare.get_m2o_change_info()
context = {"change_info": change_info}
return render_to_string("reversion-compare/compare_generic_many_to_many.html", context)

def compare_ManyToManyField(self, obj_compare):
""" create a table for m2m compare """
change_info = obj_compare.get_m2m_change_info()
Expand All @@ -710,7 +777,7 @@ def compare_FileField(self, obj_compare):
value1 = obj_compare.value1
value2 = obj_compare.value2

# FIXME: Needed to not get 'The 'file' attribute has no file associated with it.'
# FIXME: Needed to not get 'The 'file' attribute has no file associated with it.'
if value1:
value1 = value1.url
else:
Expand Down Expand Up @@ -752,4 +819,4 @@ class VersionAdmin(admin.ModelAdmin):
list_filter = ("content_type", "format")
search_fields = ("object_repr", "serialized_data")

admin.site.register(Version, VersionAdmin)
admin.site.register(Version, VersionAdmin)
Loading