Skip to content

Commit a3ac5c0

Browse files
author
matt dannenberg
committed
SERVER-10086 $set operations for agg
1 parent 7fc597e commit a3ac5c0

File tree

4 files changed

+628
-0
lines changed

4 files changed

+628
-0
lines changed

src/mongo/db/pipeline/expression.cpp

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,11 @@ namespace mongo {
283283
{"$not", ExpressionNot::create, OpDesc::FIXED_COUNT, 1},
284284
{"$or", ExpressionOr::create, 0},
285285
{"$second", ExpressionSecond::create, OpDesc::FIXED_COUNT, 1},
286+
{"$setDifference", ExpressionSetDifference::create, OpDesc::FIXED_COUNT, 2},
287+
{"$setEquals", ExpressionSetEquals::create, 0},
288+
{"$setIntersection", ExpressionSetIntersection::create, 0},
289+
{"$setIsSubset", ExpressionSetIsSubset::create, OpDesc::FIXED_COUNT, 2},
290+
{"$setUnion", ExpressionSetUnion::create, 0},
286291
{"$strcasecmp", ExpressionStrcasecmp::create, OpDesc::FIXED_COUNT, 2},
287292
{"$substr", ExpressionSubstr::create, OpDesc::FIXED_COUNT, 3},
288293
{"$subtract", ExpressionSubtract::create, OpDesc::FIXED_COUNT, 2},
@@ -2315,6 +2320,219 @@ namespace mongo {
23152320
return "$second";
23162321
}
23172322

2323+
2324+
2325+
namespace {
2326+
ValueSet arrayToSet(const Value& val) {
2327+
const vector<Value>& array = val.getArray();
2328+
return ValueSet(array.begin(), array.end());
2329+
}
2330+
}
2331+
2332+
/* ----------------------- ExpressionSetDifference ---------------------------- */
2333+
2334+
intrusive_ptr<ExpressionNary> ExpressionSetDifference::create() {
2335+
return new ExpressionSetDifference();
2336+
}
2337+
2338+
void ExpressionSetDifference::addOperand(const intrusive_ptr<Expression> &pExpression) {
2339+
checkArgLimit(2);
2340+
ExpressionNary::addOperand(pExpression);
2341+
}
2342+
2343+
Value ExpressionSetDifference::evaluateInternal(const Variables& vars) const {
2344+
checkArgCount(2);
2345+
const Value lhs = vpOperand[0]->evaluateInternal(vars);
2346+
const Value rhs = vpOperand[1]->evaluateInternal(vars);
2347+
2348+
if (lhs.nullish() || rhs.nullish()) {
2349+
return Value(BSONNULL);
2350+
}
2351+
2352+
uassert(16962, str::stream() << "both operands of $setDifference must be arrays. First "
2353+
<< "argument is of type: " << lhs.getType(),
2354+
lhs.getType() == Array);
2355+
uassert(16963, str::stream() << "both operands of $setDifference must be arrays. Second "
2356+
<< "argument is of type: " << rhs.getType(),
2357+
rhs.getType() == Array);
2358+
2359+
const ValueSet rhsSet = arrayToSet(rhs);
2360+
const vector<Value>& lhsArray = lhs.getArray();
2361+
vector<Value> returnVec;
2362+
2363+
for (vector<Value>::const_iterator it = lhsArray.begin(); it != lhsArray.end(); ++it) {
2364+
if (!rhsSet.count(*it)) {
2365+
returnVec.push_back(*it);
2366+
}
2367+
}
2368+
return Value::consume(returnVec);
2369+
}
2370+
2371+
const char *ExpressionSetDifference::getOpName() const {
2372+
return "$setDifference";
2373+
}
2374+
2375+
/* ----------------------- ExpressionSetEquals ---------------------------- */
2376+
2377+
intrusive_ptr<ExpressionNary> ExpressionSetEquals::create() {
2378+
return new ExpressionSetEquals();
2379+
}
2380+
2381+
Value ExpressionSetEquals::evaluateInternal(const Variables& vars) const {
2382+
const size_t n = vpOperand.size();
2383+
uassert(16974, str::stream() << "$setEquals needs at least two arguments had: " << n,
2384+
n >= 2);
2385+
std::set<Value> lhs;
2386+
2387+
for (size_t i = 0; i < n; i++) {
2388+
const Value nextEntry = vpOperand[i]->evaluateInternal(vars);
2389+
uassert(16971, str::stream() << "All operands of $setIntersection must be arrays. One "
2390+
<< "argument is of type: " << nextEntry.getType(),
2391+
nextEntry.getType() == Array);
2392+
2393+
if (i == 0) {
2394+
lhs.insert(nextEntry.getArray().begin(), nextEntry.getArray().end());
2395+
}
2396+
else {
2397+
const std::set<Value> rhs(nextEntry.getArray().begin(), nextEntry.getArray().end());
2398+
if (lhs != rhs) {
2399+
return Value(false);
2400+
}
2401+
}
2402+
}
2403+
return Value(true);
2404+
}
2405+
2406+
const char *ExpressionSetEquals::getOpName() const {
2407+
return "$setEquals";
2408+
}
2409+
2410+
/* ----------------------- ExpressionSetIntersection ---------------------------- */
2411+
2412+
intrusive_ptr<ExpressionNary> ExpressionSetIntersection::create() {
2413+
return new ExpressionSetIntersection();
2414+
}
2415+
2416+
Value ExpressionSetIntersection::evaluateInternal(const Variables& vars) const {
2417+
const size_t n = vpOperand.size();
2418+
ValueSet currentIntersection;
2419+
for (size_t i = 0; i < n; i++) {
2420+
const Value nextEntry = vpOperand[i]->evaluateInternal(vars);
2421+
if (nextEntry.nullish()) {
2422+
return Value(BSONNULL);
2423+
}
2424+
uassert(16966, str::stream() << "All operands of $setIntersection must be arrays. One "
2425+
<< "argument is of type: " << nextEntry.getType(),
2426+
nextEntry.getType() == Array);
2427+
2428+
if (i == 0) {
2429+
currentIntersection.insert(nextEntry.getArray().begin(),
2430+
nextEntry.getArray().end());
2431+
}
2432+
else {
2433+
ValueSet nextSet = arrayToSet(nextEntry);
2434+
if (currentIntersection.size() > nextSet.size()) {
2435+
// to iterate over whichever is the smaller set
2436+
nextSet.swap(currentIntersection);
2437+
}
2438+
2439+
for (ValueSet::iterator it = currentIntersection.begin();
2440+
it != currentIntersection.end(); ++it) {
2441+
if (!nextSet.count(*it)) {
2442+
currentIntersection.erase(*it);
2443+
}
2444+
}
2445+
}
2446+
if (currentIntersection.empty()) {
2447+
break;
2448+
}
2449+
}
2450+
vector<Value> result = vector<Value>(currentIntersection.begin(),
2451+
currentIntersection.end());
2452+
return Value::consume(result);
2453+
}
2454+
2455+
const char *ExpressionSetIntersection::getOpName() const {
2456+
return "$setIntersection";
2457+
}
2458+
2459+
intrusive_ptr<ExpressionNary> (*ExpressionSetIntersection::getFactory() const)() {
2460+
return ExpressionSetIntersection::create;
2461+
}
2462+
2463+
/* ----------------------- ExpressionSetIsSubset ---------------------------- */
2464+
2465+
intrusive_ptr<ExpressionNary> ExpressionSetIsSubset::create() {
2466+
return new ExpressionSetIsSubset();
2467+
}
2468+
2469+
void ExpressionSetIsSubset::addOperand(const intrusive_ptr<Expression> &pExpression) {
2470+
checkArgLimit(2);
2471+
ExpressionNary::addOperand(pExpression);
2472+
}
2473+
2474+
Value ExpressionSetIsSubset::evaluateInternal(const Variables& vars) const {
2475+
checkArgCount(2);
2476+
const Value lhs = vpOperand[0]->evaluateInternal(vars);
2477+
const Value rhs = vpOperand[1]->evaluateInternal(vars);
2478+
2479+
uassert(16968, str::stream() << "both operands of $setIsSubset must be arrays. First "
2480+
<< "argument is of type: " << lhs.getType(),
2481+
lhs.getType() == Array);
2482+
uassert(16969, str::stream() << "both operands of $setIsSubset must be arrays. Second "
2483+
<< "argument is of type: " << rhs.getType(),
2484+
rhs.getType() == Array);
2485+
2486+
const vector<Value>& potentialSubset = lhs.getArray();
2487+
const ValueSet& fullSet = arrayToSet(rhs);
2488+
2489+
// do not shortcircuit when potentialSubset.size() > fullSet.size()
2490+
// because potentialSubset can have redundant entries
2491+
for (vector<Value>::const_iterator it = potentialSubset.begin();
2492+
it != potentialSubset.end(); ++it) {
2493+
if (!fullSet.count(*it)) {
2494+
return Value(false);
2495+
}
2496+
}
2497+
return Value(true);
2498+
}
2499+
2500+
const char *ExpressionSetIsSubset::getOpName() const {
2501+
return "$setIsSubset";
2502+
}
2503+
2504+
/* ----------------------- ExpressionSetUnion ---------------------------- */
2505+
2506+
intrusive_ptr<ExpressionNary> ExpressionSetUnion::create() {
2507+
return new ExpressionSetUnion();
2508+
}
2509+
2510+
Value ExpressionSetUnion::evaluateInternal(const Variables& vars) const {
2511+
ValueSet unionedSet;
2512+
const size_t n = vpOperand.size();
2513+
for (size_t i = 0; i < n; i++) {
2514+
const Value newEntries = vpOperand[i]->evaluateInternal(vars);
2515+
if (newEntries.nullish()) {
2516+
return Value(BSONNULL);
2517+
}
2518+
uassert(16970, str::stream() << "All operands of $setUnion must be arrays. One argument"
2519+
<< " is of type: " << newEntries.getType(),
2520+
newEntries.getType() == Array);
2521+
2522+
unionedSet.insert(newEntries.getArray().begin(), newEntries.getArray().end());
2523+
}
2524+
vector<Value> result = vector<Value>(unionedSet.begin(), unionedSet.end());
2525+
return Value::consume(result);
2526+
}
2527+
2528+
const char *ExpressionSetUnion::getOpName() const {
2529+
return "$setUnion";
2530+
}
2531+
2532+
intrusive_ptr<ExpressionNary> (*ExpressionSetUnion::getFactory() const)() {
2533+
return ExpressionSetUnion::create;
2534+
}
2535+
23182536
/* ----------------------- ExpressionStrcasecmp ---------------------------- */
23192537

23202538
intrusive_ptr<ExpressionNary> ExpressionStrcasecmp::create() {

src/mongo/db/pipeline/expression.h

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -921,6 +921,81 @@ namespace mongo {
921921
};
922922

923923

924+
class ExpressionSetDifference :
925+
public ExpressionNary {
926+
public:
927+
// virtuals from ExpressionNary
928+
virtual Value evaluateInternal(const Variables& vars) const;
929+
virtual const char *getOpName() const;
930+
virtual void addOperand(const intrusive_ptr<Expression> &pExpression);
931+
932+
static intrusive_ptr<ExpressionNary> create();
933+
934+
private:
935+
ExpressionSetDifference() {}
936+
};
937+
938+
939+
class ExpressionSetEquals :
940+
public ExpressionNary {
941+
public:
942+
// virtuals from ExpressionNary
943+
virtual Value evaluateInternal(const Variables& vars) const;
944+
virtual const char *getOpName() const;
945+
946+
static intrusive_ptr<ExpressionNary> create();
947+
948+
private:
949+
ExpressionSetEquals() {}
950+
};
951+
952+
953+
class ExpressionSetIntersection :
954+
public ExpressionNary {
955+
public:
956+
// virtuals from ExpressionNary
957+
virtual Value evaluateInternal(const Variables& vars) const;
958+
virtual const char *getOpName() const;
959+
virtual intrusive_ptr<ExpressionNary> (*getFactory() const)();
960+
961+
static intrusive_ptr<ExpressionNary> create();
962+
963+
private:
964+
ExpressionSetIntersection() {}
965+
};
966+
967+
968+
class ExpressionSetIsSubset :
969+
public ExpressionNary {
970+
public:
971+
// virtuals from ExpressionNary
972+
virtual Value evaluateInternal(const Variables& vars) const;
973+
virtual const char *getOpName() const;
974+
virtual void addOperand(const intrusive_ptr<Expression> &pExpression);
975+
976+
static intrusive_ptr<ExpressionNary> create();
977+
978+
private:
979+
ExpressionSetIsSubset() {}
980+
};
981+
982+
983+
class ExpressionSetUnion :
984+
public ExpressionNary {
985+
public:
986+
// virtuals from ExpressionNary
987+
// virtual intrusive_ptr<Expression> optimize();
988+
virtual Value evaluateInternal(const Variables& vars) const;
989+
virtual const char *getOpName() const;
990+
virtual intrusive_ptr<ExpressionNary> (*getFactory() const)();
991+
992+
static intrusive_ptr<ExpressionNary> create();
993+
994+
private:
995+
ExpressionSetUnion() {}
996+
};
997+
998+
924999
class ExpressionStrcasecmp :
9251000
public ExpressionNary {
9261001
public:

src/mongo/db/pipeline/value.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#pragma once
1818

1919
#include "mongo/db/pipeline/value_internal.h"
20+
#include "mongo/platform/unordered_set.h"
2021

2122
namespace mongo {
2223
class BSONElement;
@@ -194,6 +195,14 @@ namespace mongo {
194195
}
195196
return (Value::compare(v1, v2) == 0);
196197
}
198+
199+
friend bool operator!=(const Value& v1, const Value& v2) {
200+
return !(v1 == v2);
201+
}
202+
203+
friend bool operator<(const Value& lhs, const Value& rhs) {
204+
return (Value::compare(lhs, rhs) < 0);
205+
}
197206

198207
/// This is for debugging, logging, etc. See getString() for how to extract a string.
199208
string toString() const;
@@ -253,6 +262,8 @@ namespace mongo {
253262
friend class MutableValue; // gets and sets _storage.genericRCPtr
254263
};
255264
BOOST_STATIC_ASSERT(sizeof(Value) == 16);
265+
266+
typedef unordered_set<Value, Value::Hash> ValueSet;
256267
}
257268

258269
namespace std {

0 commit comments

Comments
 (0)