Skip to content

Commit 3336c7b

Browse files
Feynman Liang and mengxr
Feynman Liang authored and mengxr committed
[SPARK-8559] [MLLIB] Support Association Rule Generation
Distributed generation of single-consequent association rules from a RDD of frequent itemsets. Tests referenced against `R`'s implementation of A Priori in [arules](http://cran.r-project.org/web/packages/arules/index.html). Author: Feynman Liang <[email protected]> Closes apache#7005 from feynmanliang/fp-association-rules-distributed and squashes the following commits: 466ced0 [Feynman Liang] Refactor AR generation impl 73c1cff [Feynman Liang] Make rule attributes public, remove numTransactions from FreqItemset 80f63ff [Feynman Liang] Change default confidence and optimize imports 04cf5b5 [Feynman Liang] Code review with @mengxr, add R to tests 0cc1a6a [Feynman Liang] Java compatibility test f3c14b5 [Feynman Liang] Fix MiMa test 764375e [Feynman Liang] Fix tests 1187307 [Feynman Liang] Almost working tests b20779b [Feynman Liang] Working implementation 5395c4e [Feynman Liang] Fix imports 2d34405 [Feynman Liang] Partial implementation of distributed ar 83ace4b [Feynman Liang] Local rule generation without pruning complete 69c2c87 [Feynman Liang] Working local implementation, now to parallelize../.. 4e1ec9a [Feynman Liang] Pull FreqItemsets out, refactor type param, tests 69ccedc [Feynman Liang] First implementation of association rule generation
1 parent 70beb80 commit 3336c7b

File tree

5 files changed

+258
-4
lines changed

5 files changed

+258
-4
lines changed
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.spark.mllib.fpm
18+
19+
import scala.reflect.ClassTag
20+
21+
import org.apache.spark.Logging
22+
import org.apache.spark.annotation.Experimental
23+
import org.apache.spark.api.java.JavaRDD
24+
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
25+
import org.apache.spark.mllib.fpm.AssociationRules.Rule
26+
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
27+
import org.apache.spark.rdd.RDD
28+
29+
/**
 * :: Experimental ::
 *
 * Generates association rules from a `RDD[FreqItemset[Item]]`. This method only generates
 * association rules which have a single item as the consequent.
 */
@Experimental
class AssociationRules private (
    private var minConfidence: Double) extends Logging with Serializable {

  /**
   * Constructs a default instance with default parameters {minConfidence = 0.8}.
   */
  def this() = this(0.8)

  /**
   * Sets the minimal confidence (default: `0.8`).
   */
  def setMinConfidence(minConfidence: Double): this.type = {
    // Confidence is a conditional frequency ratio; values outside [0, 1] are caller errors.
    require(minConfidence >= 0.0 && minConfidence <= 1.0,
      s"Minimal confidence must be in range [0, 1] but got $minConfidence")
    this.minConfidence = minConfidence
    this
  }

  /**
   * Computes the association rules with confidence above [[minConfidence]].
   * @param freqItemsets frequent itemset model obtained from [[FPGrowth]]
   * @return a `RDD[Rule[Item]]` containing the association rules.
   */
  def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]]): RDD[Rule[Item]] = {
    // For each candidate rule X => Y (Y a single item), emit (X, (Y, freq(X union Y))).
    val candidates = freqItemsets.flatMap { itemset =>
      val items = itemset.items
      items.flatMap { item =>
        // Split the itemset into the one-item consequent and the remaining antecedent.
        items.partition(_ == item) match {
          case (consequent, antecedent) if antecedent.nonEmpty =>
            Some((antecedent.toSeq, (consequent.toSeq, itemset.freq)))
          case _ => None // singleton itemsets yield no rules
        }
      }
    }

    // Join to get (X, ((Y, freq(X union Y)), freq(X))), generate rules, and filter by confidence
    candidates.join(freqItemsets.map(x => (x.items.toSeq, x.freq)))
      .map { case (antecedent, ((consequent, freqUnion), freqAntecedent)) =>
        new Rule(antecedent.toArray, consequent.toArray, freqUnion, freqAntecedent)
      }.filter(_.confidence >= minConfidence)
  }

  /** Java-friendly version of [[run]]. */
  def run[Item](freqItemsets: JavaRDD[FreqItemset[Item]]): JavaRDD[Rule[Item]] = {
    val tag = fakeClassTag[Item]
    run(freqItemsets.rdd)(tag)
  }
}
82+
83+
object AssociationRules {

  /**
   * :: Experimental ::
   *
   * An association rule between sets of items.
   * @param antecedent hypotheses of the rule
   * @param consequent conclusion of the rule
   * @tparam Item item type
   */
  @Experimental
  class Rule[Item] private[mllib] (
      val antecedent: Array[Item],
      val consequent: Array[Item],
      freqUnion: Double,
      freqAntecedent: Double) extends Serializable {

    // A valid rule must have disjoint sides; fail fast at construction time.
    // The message block is by-name, so the second intersect is only computed on failure.
    require(antecedent.toSet.intersect(consequent.toSet).isEmpty, {
      val sharedItems = antecedent.toSet.intersect(consequent.toSet)
      s"A valid association rule must have disjoint antecedent and " +
        s"consequent but ${sharedItems} is present in both."
    })

    /**
     * Returns the confidence of the rule:
     * freq(antecedent union consequent) / freq(antecedent).
     */
    def confidence: Double = freqUnion / freqAntecedent
  }
}

mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
2828
import org.apache.spark.annotation.Experimental
2929
import org.apache.spark.api.java.JavaRDD
3030
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
31-
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
31+
import org.apache.spark.mllib.fpm.FPGrowth._
3232
import org.apache.spark.rdd.RDD
3333
import org.apache.spark.storage.StorageLevel
3434

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.spark.mllib.fpm;
18+
19+
import java.io.Serializable;
import java.util.List;

import static org.junit.Assert.assertEquals;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import com.google.common.collect.Lists;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
29+
30+
31+
public class JavaAssociationRulesSuite implements Serializable {
32+
private transient JavaSparkContext sc;
33+
34+
@Before
35+
public void setUp() {
36+
sc = new JavaSparkContext("local", "JavaFPGrowth");
37+
}
38+
39+
@After
40+
public void tearDown() {
41+
sc.stop();
42+
sc = null;
43+
}
44+
45+
@Test
46+
public void runAssociationRules() {
47+
48+
@SuppressWarnings("unchecked")
49+
JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = sc.parallelize(Lists.newArrayList(
50+
new FreqItemset<String>(new String[] {"a"}, 15L),
51+
new FreqItemset<String>(new String[] {"b"}, 35L),
52+
new FreqItemset<String>(new String[] {"a", "b"}, 18L)
53+
));
54+
55+
JavaRDD<AssociationRules.Rule<String>> results = (new AssociationRules()).run(freqItemsets);
56+
}
57+
}
58+

mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929

3030
import org.apache.spark.api.java.JavaRDD;
3131
import org.apache.spark.api.java.JavaSparkContext;
32-
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
3332

3433
public class JavaFPGrowthSuite implements Serializable {
3534
private transient JavaSparkContext sc;
@@ -62,10 +61,10 @@ public void runFPGrowth() {
6261
.setNumPartitions(2)
6362
.run(rdd);
6463

65-
List<FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
64+
List<FPGrowth.FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
6665
assertEquals(18, freqItemsets.size());
6766

68-
for (FreqItemset<String> itemset: freqItemsets) {
67+
for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
6968
// Test return types.
7069
List<String> items = itemset.javaItems();
7170
long freq = itemset.freq();
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.spark.mllib.fpm
18+
19+
import org.apache.spark.SparkFunSuite
20+
import org.apache.spark.mllib.util.MLlibTestSparkContext
21+
22+
class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("association rules using String type") {
    // Frequent itemsets (with their absolute frequencies) mined from the six
    // transactions listed in the R snippets below.
    val freqItemsets = sc.parallelize(Seq(
      (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
      (Set("r"), 3L),
      (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L),
      (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L),
      (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L),
      (Set("t", "y", "x"), 3L),
      (Set("t", "y", "x", "z"), 3L)
    ).map {
      case (items, freq) => new FPGrowth.FreqItemset(items.toArray, freq)
    })

    val ar = new AssociationRules()

    val results1 = ar
      .setMinConfidence(0.9)
      .run(freqItemsets)
      .collect()

    /* Verify results using the `R` code:
       transactions = as(sapply(
         list("r z h k p",
              "z y x w v u t s",
              "s x o n r",
              "x z y m t s q e",
              "z",
              "x z y r q t p"),
         FUN=function(x) strsplit(x," ",fixed=TRUE)),
         "transactions")
       ars = apriori(transactions,
                     parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2))
       arsDF = as(ars, "data.frame")
       arsDF$support = arsDF$support * length(transactions)
       names(arsDF)[names(arsDF) == "support"] = "freq"
       > nrow(arsDF)
       [1] 23
       > sum(arsDF$confidence == 1)
       [1] 23
     */
    // Use ScalaTest's `===` consistently for informative failure messages.
    assert(results1.size === 23)
    assert(results1.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) === 23)

    val results2 = ar
      .setMinConfidence(0)
      .run(freqItemsets)
      .collect()

    /* Verify results using the `R` code:
       ars = apriori(transactions,
                     parameter = list(support = 0.5, confidence = 0.5, target="rules", minlen=2))
       arsDF = as(ars, "data.frame")
       arsDF$support = arsDF$support * length(transactions)
       names(arsDF)[names(arsDF) == "support"] = "freq"
       nrow(arsDF)
       sum(arsDF$confidence == 1)
       > nrow(arsDF)
       [1] 30
       > sum(arsDF$confidence == 1)
       [1] 23
     */
    assert(results2.size === 30)
    assert(results2.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) === 23)
  }
}
89+

0 commit comments

Comments
 (0)