关联规则用于发现交易数据中不同商品之间的关系,这些规则反映了顾客的购买行为模式,例如顾客经常在购买A商品的同时也会购买B商品。著名的“啤酒与尿布”案例就是关联规则的成功应用。
导语
Apriori算法是关联规则的基本算法,很多用于发现关联规则的算法都是基于Apriori算法,但Apriori算法需要多次访问数据库,具有严重的性能问题。FP-Growth算法只需要两次扫描数据库,相比于Apriori减少了I/O操作,克服了Apriori算法需要多次扫描数据库的问题。本文采用如下的样例数据
A;B;E;
B;D;
B;C;
A;B;D;
A;C;
B;C;
A;C;
A;B;C;E;
A;B;C;
FP-Growth生成FP-Tree
FP-Growth算法将数据库中的频繁项集压缩到一棵频繁模式树(FP-Tree)中,同时保持了频繁项集之间的关联关系。通过对该频繁模式树进行挖掘,可以得到频繁项集。其过程如下:
- 第一次扫描数据库,产生频繁1项集,剪除支持数低于阈值的元素,并将剩余的项按照频数降序排列,处理后得到L集合。
- 第二次扫描数据库,对数据库的每个交易事务中的项按照L集合中项出现的顺序排序,生成FP-Tree。
产生fp-tree的步骤可以分解为如下:
从FP-Tree挖掘频繁项集
从FP-Tree中可以挖掘出频繁项集,其过程如下:
从频繁1项集链表中按照逆序开始,链表可以追溯到每个具有相同项的节点。
- 从链表中找到项“E”,追溯出FP-Tree中有两个带“E”的节点,由这两个节点分别向上(parent)追溯,形成两条模式:<E,C,A,B;1>,<E,A,B;1>.
- 由这两条模式得到项“E”的条件模式<A,B;2>.
- 根据条件模式,得到项“E”的频繁项集(不包含频繁1项集):<E,A;2>,<E,B;2>,<E,A,B;2>
- 然后依次得到项“D”,“C”,“A”的频繁项集。
FP-Growth算法简单实现
下面是该算法的一个简单实现(其中由FP-Tree挖掘频繁项模式的实现比较繁琐;关联规则本可以直接通过FP-Tree计算得到,但本文是由已得到的频繁项集计算的)。
package association;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.regex.Pattern;
import set.SetUtils;
/**
 * A small FP-Growth implementation: builds an FP-Tree from {@link #transList},
 * mines the frequent itemsets from it and derives association rules.
 *
 * <p>Every item is identified by its "key", i.e. the item text followed by
 * {@link #SPLIT} (for example {@code "A;"}). An itemset is identified by the
 * concatenation of its item keys (for example {@code "E;B;A;"}).
 *
 * <p>Configuration lives in mutable static fields, so all instances share it.
 *
 * @author jeff
 */
public class FPTree {
    /** Separator between the items of a transaction; also the suffix of every item key. */
    public static String SPLIT = ";";
    /** Minimum support count: an itemset is frequent when its count is {@code >= F}. */
    public static int F = 2;
    /** Minimum confidence: a rule is kept when its confidence is strictly greater than {@code C}. */
    public static double C = 0.7;
    /** The transaction database; each entry is a {@link #SPLIT}-separated list of items. */
    public static List<String> transList = new ArrayList<String>();

    /** Header table: item key to the most recently created node for that item. */
    Map<String, TreeNode> map = new HashMap<String, FPTree.TreeNode>();
    /** Root of the FP-Tree (its key and value are {@code null}); {@code null} until a tree is built. */
    TreeNode root = null;
    /** Frequent 1-itemsets (item key to support count) of the last tree build. */
    Map<String, Integer> f1Items;
    /** Frequent item keys ordered by descending support, ties broken by key. */
    List<String> sortedList;

    static {
        transList.add("A;B;E;");
        transList.add("B;D;");
        transList.add("B;C;");
        transList.add("A;B;D;");
        transList.add("A;C;");
        transList.add("B;C;");
        transList.add("A;C;");
        transList.add("A;B;C;E;");
        transList.add("A;B;C;");
    }

    /**
     * @return the header table built by the last call to {@link #generateFPTree()}
     */
    public Map<String, TreeNode> getMap() {
        return map;
    }

    /**
     * @return the root of the last built FP-Tree, or {@code null} if none was built yet
     */
    public TreeNode getRoot() {
        return root;
    }

    /**
     * Node of the FP-Tree.
     *
     * @author jeff
     */
    public class TreeNode {
        /** Item key stored in this node ({@code null} for the root). */
        public String key;
        /** Number of transactions whose sorted prefix passes through this node. */
        public Integer value;
        /** Parent node, {@code null} for the root. */
        public TreeNode parent;
        /** Child nodes. */
        public List<TreeNode> children;
        /** Next node carrying the same item key (node-link chain), or {@code null}. */
        public TreeNode linked;

        public TreeNode(String key, Integer value) {
            this.key = key;
            this.value = value;
            this.children = new ArrayList<FPTree.TreeNode>();
        }

        public TreeNode getParent() {
            return this.parent;
        }

        public void setParent(TreeNode parent) {
            this.parent = parent;
        }

        public List<TreeNode> getChildren() {
            return this.children;
        }

        public void setChildren(List<TreeNode> children) {
            this.children = children;
        }

        /**
         * Appends {@code child} to this node and makes this node its parent.
         *
         * @param child the node to attach
         */
        public void addChild(TreeNode child) {
            // setChildren(null) is allowed, so re-create the list lazily.
            if (this.children == null) {
                this.children = new ArrayList<FPTree.TreeNode>();
            }
            this.children.add(child);
            child.setParent(this);
        }

        public TreeNode getLinked() {
            return this.linked;
        }

        public void setLinked(TreeNode linked) {
            this.linked = linked;
        }
    }

    /**
     * Splits a transaction (or an itemset key) into its distinct item keys.
     *
     * <p>{@link #SPLIT} is treated as a literal separator, empty tokens are
     * ignored and a repeated item is reported once.
     *
     * @param tran the separated item list; {@code null} yields an empty set
     * @return the distinct item keys (item text + {@link #SPLIT})
     */
    private Set<String> itemKeys(String tran) {
        Set<String> keys = new HashSet<String>();
        if (tran == null) {
            return keys;
        }
        for (String token : tran.split(Pattern.quote(SPLIT))) {
            if (token.length() > 0) {
                keys.add(token + SPLIT);
            }
        }
        return keys;
    }

    /**
     * Scans the database once and returns the frequent 1-itemsets.
     *
     * <p>Support is counted per transaction, so an item repeated inside one
     * transaction contributes a single count.
     *
     * @return item key to support count, restricted to counts {@code >= F}
     */
    private Map<String, Integer> getF1Item() {
        Map<String, Integer> counts = new HashMap<String, Integer>();
        for (String tran : transList) {
            for (String key : itemKeys(tran)) {
                Integer seen = counts.get(key);
                counts.put(key, seen == null ? 1 : seen + 1);
            }
        }
        Map<String, Integer> frequent = new HashMap<String, Integer>();
        for (Map.Entry<String, Integer> entry : counts.entrySet()) {
            if (entry.getValue() >= F) {
                frequent.put(entry.getKey(), entry.getValue());
            }
        }
        return frequent;
    }

    /**
     * Orders the keys of {@code oriMap} by descending value; equal values are
     * ordered by key so the result does not depend on hash iteration order.
     *
     * @param oriMap item key to support count; may be {@code null} or empty
     * @return the ordered keys (empty when {@code oriMap} is {@code null} or empty)
     */
    private List<String> sortMapByValue(Map<String, Integer> oriMap) {
        List<String> sortedNodes = new ArrayList<String>();
        if (oriMap == null || oriMap.isEmpty()) {
            return sortedNodes;
        }
        List<Map.Entry<String, Integer>> entries =
                new ArrayList<Map.Entry<String, Integer>>(oriMap.entrySet());
        Collections.sort(entries, new Comparator<Map.Entry<String, Integer>>() {
            @Override
            public int compare(Entry<String, Integer> left, Entry<String, Integer> right) {
                // compareTo instead of subtraction: no overflow on extreme counts.
                int byCount = right.getValue().compareTo(left.getValue());
                return byCount != 0 ? byCount : left.getKey().compareTo(right.getKey());
            }
        });
        for (Map.Entry<String, Integer> entry : entries) {
            sortedNodes.add(entry.getKey());
        }
        return sortedNodes;
    }

    /**
     * Projects a transaction onto the frequent items, in the order of
     * {@code sortedItems}. Items that are not in {@code sortedItems} are dropped.
     *
     * @param tran        the raw transaction
     * @param sortedItems frequent item keys in tree order
     * @return the item keys of the transaction in tree order
     */
    private List<String> getSortedTrans(String tran, List<String> sortedItems) {
        Set<String> present = itemKeys(tran);
        List<String> ordered = new ArrayList<String>();
        for (String key : sortedItems) {
            if (present.contains(key)) {
                ordered.add(key);
            }
        }
        return ordered;
    }

    /**
     * Builds the FP-Tree for {@link #transList}, replacing any previous tree,
     * header table, frequent 1-itemsets and item order held by this instance.
     *
     * @return the root of the new tree
     */
    public TreeNode generateFPTree() {
        map.clear();
        root = new TreeNode(null, null);
        // Keep only the frequent items: infrequent ones must not enter the tree.
        f1Items = getF1Item();
        sortedList = sortMapByValue(f1Items);
        for (String tran : transList) {
            TreeNode current = root;
            for (String key : getSortedTrans(tran, sortedList)) {
                TreeNode next = null;
                for (TreeNode candidate : current.getChildren()) {
                    if (candidate.key.equals(key)) {
                        next = candidate;
                        break;
                    }
                }
                if (next == null) {
                    next = new TreeNode(key, 1);
                    // Chain the new node in front of the existing nodes of this item.
                    next.linked = map.get(key);
                    map.put(key, next);
                    current.addChild(next);
                } else {
                    next.value = next.value + 1;
                }
                current = next;
            }
        }
        return root;
    }

    /**
     * One entry of a conditional pattern base: the items found on the path
     * from a node of {@code item} up to the root, and that node's count.
     *
     * @author jeff
     */
    public class TMP {
        String item;
        Integer count;
        Set<String> items;

        public TMP(String item) {
            this.item = item;
            this.count = 0;
            this.items = new HashSet<String>();
        }
    }

    /**
     * Collects the conditional pattern base of {@code item} by following the
     * node-link chain from the header table and walking each node up to the root.
     *
     * @param item the item key
     * @return one entry per tree node that carries {@code item}
     */
    private List<TMP> conditionalPatternBase(String item) {
        List<TMP> base = new ArrayList<FPTree.TMP>();
        for (TreeNode node = map.get(item); node != null; node = node.linked) {
            TMP entry = new TMP(item);
            entry.count = node.value;
            // The root has a null key, which ends the walk.
            for (TreeNode up = node.parent; up != null && up.key != null; up = up.parent) {
                entry.items.add(up.key);
            }
            base.add(entry);
        }
        return base;
    }

    /**
     * Adds {@code count} to the support of every itemset formed by
     * {@code keySoFar} plus a non-empty selection of {@code path[from..]}.
     * The number of itemsets is exponential in the path length.
     *
     * @param keySoFar itemset key built so far
     * @param path     prefix items in tree order
     * @param from     first index of {@code path} that may still be selected
     * @param count    number of transactions represented by the path
     * @param supports accumulator: itemset key to support count
     */
    private void accumulateSubsets(String keySoFar, List<String> path, int from, int count,
            Map<String, Integer> supports) {
        for (int i = from; i < path.size(); i++) {
            String key = keySoFar + path.get(i);
            Integer old = supports.get(key);
            supports.put(key, old == null ? count : old + count);
            accumulateSubsets(key, path, i + 1, count, supports);
        }
    }

    /**
     * Rebuilds the FP-Tree and mines all frequent itemsets from it.
     *
     * <p>The support of an itemset ending in item {@code x} is the sum of the
     * counts of all {@code x} nodes whose path to the root contains the other
     * items, so itemsets spread over any number of branches are counted in full.
     *
     * @return itemset key to support count, for every itemset with count
     *         {@code >= F}. A key starts with the item that comes last in the
     *         tree order, followed by the remaining items in tree order
     *         (descending support).
     */
    public Map<String, Integer> getFItems() {
        Map<String, Integer> fItems = new HashMap<String, Integer>();
        generateFPTree();
        fItems.putAll(f1Items);
        for (int i = sortedList.size() - 1; i >= 0; i--) {
            String item = sortedList.get(i);
            Map<String, Integer> supports = new HashMap<String, Integer>();
            for (TMP pattern : conditionalPatternBase(item)) {
                // Put the prefix items into tree order so every itemset has one key.
                List<String> ordered = new ArrayList<String>();
                for (String key : sortedList) {
                    if (pattern.items.contains(key)) {
                        ordered.add(key);
                    }
                }
                accumulateSubsets(item, ordered, 0, pattern.count, supports);
            }
            for (Map.Entry<String, Integer> entry : supports.entrySet()) {
                if (entry.getValue() >= F) {
                    fItems.put(entry.getKey(), entry.getValue());
                }
            }
        }
        return fItems;
    }

    /**
     * Derives association rules from the frequent itemsets.
     *
     * <p>For every ordered pair of disjoint frequent itemsets {@code X}, {@code Y}
     * whose union is frequent, the confidence is
     * {@code support(X union Y) / support(X)}; the rule is kept when it is
     * strictly greater than {@link #C}.
     *
     * @return rule text {@code "<key of X>-><key of Y>"} to its confidence
     */
    public Map<String, Double> getRules() {
        Map<String, Double> rules = new HashMap<String, Double>();
        Map<String, Integer> fItems = getFItems();
        // Index the itemsets by their item set so the union is looked up exactly
        // instead of by substring matching on the key text.
        Map<String, Set<String>> itemsByKey = new HashMap<String, Set<String>>();
        Map<Set<String>, Integer> supportByItems = new HashMap<Set<String>, Integer>();
        for (Map.Entry<String, Integer> entry : fItems.entrySet()) {
            Set<String> items = itemKeys(entry.getKey());
            itemsByKey.put(entry.getKey(), items);
            supportByItems.put(items, entry.getValue());
        }
        for (Map.Entry<String, Set<String>> antecedent : itemsByKey.entrySet()) {
            int antecedentSupport = fItems.get(antecedent.getKey());
            for (Map.Entry<String, Set<String>> consequent : itemsByKey.entrySet()) {
                if (!Collections.disjoint(antecedent.getValue(), consequent.getValue())) {
                    continue;
                }
                Set<String> union = new HashSet<String>(antecedent.getValue());
                union.addAll(consequent.getValue());
                Integer unionSupport = supportByItems.get(union);
                if (unionSupport == null) {
                    continue; // the union is not frequent
                }
                double conf = unionSupport / (double) antecedentSupport;
                if (conf > C) {
                    rules.put(antecedent.getKey() + "->" + consequent.getKey(), conf);
                }
            }
        }
        return rules;
    }

    /**
     * Prints the frequent itemsets and the association rules of the sample data.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        FPTree tree = new FPTree();
        Map<String, Integer> fItems = tree.getFItems();
        for (Entry<String, Integer> item : fItems.entrySet()) {
            System.out.println(item.getKey() + ": " + item.getValue());
        }
        Map<String, Double> rules = tree.getRules();
        for (Entry<String, Double> item : rules.entrySet()) {
            System.out.println(item.getKey() + ": " + item.getValue());
        }
    }
}