Merge similar algorithms into roles_is_member_of().
authorNoah Misch <[email protected]>
Fri, 26 Mar 2021 17:42:16 +0000 (10:42 -0700)
committerNoah Misch <[email protected]>
Fri, 26 Mar 2021 17:42:16 +0000 (10:42 -0700)
The next commit would have complicated two or three algorithms, so take
this opportunity to consolidate.  No functional changes.

Reviewed by John Naylor.

Discussion: https://postgr.es/m/20201228043148[email protected]

src/backend/utils/adt/acl.c

index c7f029e2186a1cd73683187b516cd37b54850d12..e6b4bdbd7685b590e3802438ed753906417f0586 100644 (file)
@@ -50,32 +50,24 @@ typedef struct
 /*
  * We frequently need to test whether a given role is a member of some other
  * role.  In most of these tests the "given role" is the same, namely the
- * active current user.  So we can optimize it by keeping a cached list of
- * all the roles the "given role" is a member of, directly or indirectly.
- *
- * There are actually two caches, one computed under "has_privs" rules
- * (do not recurse where rolinherit isn't true) and one computed under
- * "is_member" rules (recurse regardless of rolinherit).
+ * active current user.  So we can optimize it by keeping cached lists of all
+ * the roles the "given role" is a member of, directly or indirectly.
  *
  * Possibly this mechanism should be generalized to allow caching membership
  * info for multiple roles?
  *
- * The has_privs cache is:
- * cached_privs_role is the role OID the cache is for.
- * cached_privs_roles is an OID list of roles that cached_privs_role
- *     has the privileges of (always including itself).
- * The cache is valid if cached_privs_role is not InvalidOid.
- *
- * The is_member cache is similarly:
- * cached_member_role is the role OID the cache is for.
- * cached_membership_roles is an OID list of roles that cached_member_role
- *     is a member of (always including itself).
- * The cache is valid if cached_member_role is not InvalidOid.
+ * Each element of cached_roles is an OID list of constituent roles for the
+ * corresponding element of cached_role (always including the cached_role
+ * itself).  One cache has ROLERECURSE_PRIVS semantics, and the other has
+ * ROLERECURSE_MEMBERS semantics.
  */
-static Oid cached_privs_role = InvalidOid;
-static List *cached_privs_roles = NIL;
-static Oid cached_member_role = InvalidOid;
-static List *cached_membership_roles = NIL;
+enum RoleRecurseType
+{
+   ROLERECURSE_PRIVS = 0,      /* recurse if rolinherit */
+   ROLERECURSE_MEMBERS = 1     /* recurse unconditionally */
+};
+static Oid cached_role[] = {InvalidOid, InvalidOid};
+static List *cached_roles[] = {NIL, NIL};
 
 
 static const char *getid(const char *s, char *n);
@@ -4675,8 +4667,8 @@ initialize_acl(void)
    {
        /*
         * In normal mode, set a callback on any syscache invalidation of rows
-        * of pg_auth_members (for each AUTHMEM search in this file) or
-        * pg_authid (for has_rolinherit())
+        * of pg_auth_members (for roles_is_member_of()) or pg_authid (for
+        * has_rolinherit())
         */
        CacheRegisterSyscacheCallback(AUTHMEMROLEMEM,
                                      RoleMembershipCacheCallback,
@@ -4695,8 +4687,8 @@ static void
 RoleMembershipCacheCallback(Datum arg, int cacheid, uint32 hashvalue)
 {
    /* Force membership caches to be recomputed on next use */
-   cached_privs_role = InvalidOid;
-   cached_member_role = InvalidOid;
+   cached_role[ROLERECURSE_PRIVS] = InvalidOid;
+   cached_role[ROLERECURSE_MEMBERS] = InvalidOid;
 }
 
 
@@ -4718,30 +4710,35 @@ has_rolinherit(Oid roleid)
 
 
 /*
- * Get a list of roles that the specified roleid has the privileges of
+ * Get a list of roles that the specified roleid is a member of
  *
- * This is defined not to recurse through roles that don't have rolinherit
- * set; for such roles, membership implies the ability to do SET ROLE, but
- * the privileges are not available until you've done so.
+ * Type ROLERECURSE_PRIVS recurses only through roles that have rolinherit
+ * set, while ROLERECURSE_MEMBERS recurses through all roles.  This sets
+ * *is_admin==true if and only if role "roleid" has an ADMIN OPTION membership
+ * in role "admin_of".
  *
  * Since indirect membership testing is relatively expensive, we cache
  * a list of memberships.  Hence, the result is only guaranteed good until
- * the next call of roles_has_privs_of()!
+ * the next call of roles_is_member_of()!
  *
  * For the benefit of select_best_grantor, the result is defined to be
  * in breadth-first order, ie, closer relationships earlier.
  */
 static List *
-roles_has_privs_of(Oid roleid)
+roles_is_member_of(Oid roleid, enum RoleRecurseType type,
+                  Oid admin_of, bool *is_admin)
 {
    List       *roles_list;
    ListCell   *l;
-   List       *new_cached_privs_roles;
+   List       *new_cached_roles;
    MemoryContext oldctx;
 
-   /* If cache is already valid, just return the list */
-   if (OidIsValid(cached_privs_role) && cached_privs_role == roleid)
-       return cached_privs_roles;
+   Assert(OidIsValid(admin_of) == PointerIsValid(is_admin));
+
+   /* If cache is valid and ADMIN OPTION not sought, just return the list */
+   if (cached_role[type] == roleid && !OidIsValid(admin_of) &&
+       OidIsValid(cached_role[type]))
+       return cached_roles[type];
 
    /*
     * Find all the roles that roleid is a member of, including multi-level
@@ -4762,9 +4759,8 @@ roles_has_privs_of(Oid roleid)
        CatCList   *memlist;
        int         i;
 
-       /* Ignore non-inheriting roles */
-       if (!has_rolinherit(memberid))
-           continue;
+       if (type == ROLERECURSE_PRIVS && !has_rolinherit(memberid))
+           continue;           /* ignore non-inheriting roles */
 
        /* Find roles that memberid is directly a member of */
        memlist = SearchSysCacheList1(AUTHMEMMEMROLE,
@@ -4775,83 +4771,13 @@ roles_has_privs_of(Oid roleid)
            Oid         otherid = ((Form_pg_auth_members) GETSTRUCT(tup))->roleid;
 
            /*
-            * Even though there shouldn't be any loops in the membership
-            * graph, we must test for having already seen this role. It is
-            * legal for instance to have both A->B and A->C->B.
+            * While otherid==InvalidOid shouldn't appear in the catalog, the
+            * OidIsValid() avoids crashing if that arises.
             */
-           roles_list = list_append_unique_oid(roles_list, otherid);
-       }
-       ReleaseSysCacheList(memlist);
-   }
-
-   /*
-    * Copy the completed list into TopMemoryContext so it will persist.
-    */
-   oldctx = MemoryContextSwitchTo(TopMemoryContext);
-   new_cached_privs_roles = list_copy(roles_list);
-   MemoryContextSwitchTo(oldctx);
-   list_free(roles_list);
-
-   /*
-    * Now safe to assign to state variable
-    */
-   cached_privs_role = InvalidOid; /* just paranoia */
-   list_free(cached_privs_roles);
-   cached_privs_roles = new_cached_privs_roles;
-   cached_privs_role = roleid;
-
-   /* And now we can return the answer */
-   return cached_privs_roles;
-}
-
-
-/*
- * Get a list of roles that the specified roleid is a member of
- *
- * This is defined to recurse through roles regardless of rolinherit.
- *
- * Since indirect membership testing is relatively expensive, we cache
- * a list of memberships.  Hence, the result is only guaranteed good until
- * the next call of roles_is_member_of()!
- */
-static List *
-roles_is_member_of(Oid roleid)
-{
-   List       *roles_list;
-   ListCell   *l;
-   List       *new_cached_membership_roles;
-   MemoryContext oldctx;
-
-   /* If cache is already valid, just return the list */
-   if (OidIsValid(cached_member_role) && cached_member_role == roleid)
-       return cached_membership_roles;
-
-   /*
-    * Find all the roles that roleid is a member of, including multi-level
-    * recursion.  The role itself will always be the first element of the
-    * resulting list.
-    *
-    * Each element of the list is scanned to see if it adds any indirect
-    * memberships.  We can use a single list as both the record of
-    * already-found memberships and the agenda of roles yet to be scanned.
-    * This is a bit tricky but works because the foreach() macro doesn't
-    * fetch the next list element until the bottom of the loop.
-    */
-   roles_list = list_make1_oid(roleid);
-
-   foreach(l, roles_list)
-   {
-       Oid         memberid = lfirst_oid(l);
-       CatCList   *memlist;
-       int         i;
-
-       /* Find roles that memberid is directly a member of */
-       memlist = SearchSysCacheList1(AUTHMEMMEMROLE,
-                                     ObjectIdGetDatum(memberid));
-       for (i = 0; i < memlist->n_members; i++)
-       {
-           HeapTuple   tup = &memlist->members[i]->tuple;
-           Oid         otherid = ((Form_pg_auth_members) GETSTRUCT(tup))->roleid;
+           if (otherid == admin_of &&
+               ((Form_pg_auth_members) GETSTRUCT(tup))->admin_option &&
+               OidIsValid(admin_of))
+               *is_admin = true;
 
            /*
             * Even though there shouldn't be any loops in the membership
@@ -4867,20 +4793,20 @@ roles_is_member_of(Oid roleid)
     * Copy the completed list into TopMemoryContext so it will persist.
     */
    oldctx = MemoryContextSwitchTo(TopMemoryContext);
-   new_cached_membership_roles = list_copy(roles_list);
+   new_cached_roles = list_copy(roles_list);
    MemoryContextSwitchTo(oldctx);
    list_free(roles_list);
 
    /*
     * Now safe to assign to state variable
     */
-   cached_member_role = InvalidOid;    /* just paranoia */
-   list_free(cached_membership_roles);
-   cached_membership_roles = new_cached_membership_roles;
-   cached_member_role = roleid;
+   cached_role[type] = InvalidOid; /* just paranoia */
+   list_free(cached_roles[type]);
+   cached_roles[type] = new_cached_roles;
+   cached_role[type] = roleid;
 
    /* And now we can return the answer */
-   return cached_membership_roles;
+   return cached_roles[type];
 }
 
 
@@ -4906,7 +4832,9 @@ has_privs_of_role(Oid member, Oid role)
     * Find all the roles that member has the privileges of, including
     * multi-level recursion, then see if target role is any one of them.
     */
-   return list_member_oid(roles_has_privs_of(member), role);
+   return list_member_oid(roles_is_member_of(member, ROLERECURSE_PRIVS,
+                                             InvalidOid, NULL),
+                          role);
 }
 
 
@@ -4930,7 +4858,9 @@ is_member_of_role(Oid member, Oid role)
     * Find all the roles that member is a member of, including multi-level
     * recursion, then see if target role is any one of them.
     */
-   return list_member_oid(roles_is_member_of(member), role);
+   return list_member_oid(roles_is_member_of(member, ROLERECURSE_MEMBERS,
+                                             InvalidOid, NULL),
+                          role);
 }
 
 /*
@@ -4964,7 +4894,9 @@ is_member_of_role_nosuper(Oid member, Oid role)
     * Find all the roles that member is a member of, including multi-level
     * recursion, then see if target role is any one of them.
     */
-   return list_member_oid(roles_is_member_of(member), role);
+   return list_member_oid(roles_is_member_of(member, ROLERECURSE_MEMBERS,
+                                             InvalidOid, NULL),
+                          role);
 }
 
 
@@ -4977,8 +4909,6 @@ bool
 is_admin_of_role(Oid member, Oid role)
 {
    bool        result = false;
-   List       *roles_list;
-   ListCell   *l;
 
    if (superuser_arg(member))
        return true;
@@ -5016,44 +4946,7 @@ is_admin_of_role(Oid member, Oid role)
        return member == GetSessionUserId() &&
            !InLocalUserIdChange() && !InSecurityRestrictedOperation();
 
-   /*
-    * Find all the roles that member is a member of, including multi-level
-    * recursion.  We build a list in the same way that is_member_of_role does
-    * to track visited and unvisited roles.
-    */
-   roles_list = list_make1_oid(member);
-
-   foreach(l, roles_list)
-   {
-       Oid         memberid = lfirst_oid(l);
-       CatCList   *memlist;
-       int         i;
-
-       /* Find roles that memberid is directly a member of */
-       memlist = SearchSysCacheList1(AUTHMEMMEMROLE,
-                                     ObjectIdGetDatum(memberid));
-       for (i = 0; i < memlist->n_members; i++)
-       {
-           HeapTuple   tup = &memlist->members[i]->tuple;
-           Oid         otherid = ((Form_pg_auth_members) GETSTRUCT(tup))->roleid;
-
-           if (otherid == role &&
-               ((Form_pg_auth_members) GETSTRUCT(tup))->admin_option)
-           {
-               /* Found what we came for, so can stop searching */
-               result = true;
-               break;
-           }
-
-           roles_list = list_append_unique_oid(roles_list, otherid);
-       }
-       ReleaseSysCacheList(memlist);
-       if (result)
-           break;
-   }
-
-   list_free(roles_list);
-
+   (void) roles_is_member_of(member, ROLERECURSE_MEMBERS, role, &result);
    return result;
 }
 
@@ -5125,10 +5018,11 @@ select_best_grantor(Oid roleId, AclMode privileges,
    /*
     * Otherwise we have to do a careful search to see if roleId has the
     * privileges of any suitable role.  Note: we can hang onto the result of
-    * roles_has_privs_of() throughout this loop, because aclmask_direct()
+    * roles_is_member_of() throughout this loop, because aclmask_direct()
     * doesn't query any role memberships.
     */
-   roles_list = roles_has_privs_of(roleId);
+   roles_list = roles_is_member_of(roleId, ROLERECURSE_PRIVS,
+                                   InvalidOid, NULL);
 
    /* initialize candidate result as default */
    *grantorId = roleId;