数据挖掘系列002 关联规则FP-Growth算法

关联规则用于发现交易数据中,不同商品之间的关系,这些规则反映了顾客的购买行为模式。如顾客经常在购买A商品的时候也会购买B商品,著名的“啤酒与尿布”的案例就是关联规则的成功应用案例

导语

Apriori算法是关联规则的基本算法,很多用于发现关联规则的算法都是基于Apriori算法,但Apriori算法需要多次访问数据库,具有严重的性能问题。FP-Growth算法只需要两次扫描数据库,相比于Apriori减少了I/O操作,克服了Apriori算法需要多次扫描数据库的问题。本文采用如下的样例数据

	A;B;E;
	B;D;
	B;C;
	A;B;D
	A;C;
	B;C;
	A;C;
	A;B;C;E;
	A;B;C;

FP-Growth生成FP-Tree

FP-Growth算法将数据库中的频繁项集压缩到一颗频繁模式树中,同时保持了频繁项集之间的关联关系。通过对该频繁模式树进行挖掘,可以得到频繁项集。其过程如下:

  1. 第一次扫描数据库,产生频繁1项集,并对产生的频繁项集按照频数降序排列,并剪枝支持数低于阀值的元素。处理后得到L集合,
  2. 第二次扫描数据库,对数据库的每个交易事务中的项按照L集合中项出现的顺序排序,生成FP-Tree。

image

产生fp-tree的步骤可以分解为如下:

image

从FP-Tree挖掘频繁项集

从FP-Tree重可以挖掘出频繁项集,其过程如下:

image

从频繁1项集链表中按照逆序开始,链表可以追溯到每个具有相同项的节点。

  1. 从链表中找到项“E”,追溯出FP-Tree中有两个带“E”的节点,由这两个节点分别向上(parent)追溯,形成两条模式:<E,C,A,B;1>,<E,A,B;1>.
  2. 由这两条模式得到项“E”的条件模式<A,B;2>.
  3. 根据条件模式,得到项“E”的频繁项集(不包含频繁1项集):<E,A;2>,<E,B;2>,<E,A,B;2>
  4. 然后一次得到项“D”,“C”,“A”。

FP-Growth算法简单实现

下面是该算法的一个简单的实现(其中由FP-Tree挖掘频繁项模式实现比较繁琐,计算关联规则时可以通过FP-Tree可以得到,但本文是通过得到的频繁项集得到的)。

package association;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

import set.SetUtils;

/**
 * 
 * @author jeff
 * 
 */
public class FPTree {
	public static String SPLIT = ";";
	public static int F = 2;
	public static double C = 0.7;
	public static List<String> transList = new ArrayList<String>();
	Map<String, TreeNode> map = new HashMap<String, FPTree.TreeNode>();// 用来
	TreeNode root = null;
	Map<String, Integer> f1Items;
	List<String> sortedList;

	public Map<String, TreeNode> getMap() {
		return map;
	}

	public TreeNode getRoot() {
		return root;
	}

	static {
		transList.add("A;B;E;");
		transList.add("B;D;");
		transList.add("B;C;");
		transList.add("A;B;D;");
		transList.add("A;C;");
		transList.add("B;C;");
		transList.add("A;C;");
		transList.add("A;B;C;E;");
		transList.add("A;B;C;");
	}

	/**
	 * FP-Tree的节点结构
	 * @author jeff
	 *
	 */
	public class TreeNode {
		public String key;
		public Integer value;

		public TreeNode(String key, Integer value) {
			this.key = key;
			this.value = value;
			this.children = new ArrayList<FPTree.TreeNode>();
		}

		public TreeNode getParent() {
			return this.parent;
		}

		public void setParent(TreeNode parent) {
			this.parent = parent;
		}

		public List<TreeNode> getChildren() {
			return this.children;
		}

		public void setChildren(List<TreeNode> children) {
			this.children = children;
		}

		public void addChild(TreeNode child) {
			if (this.children == null) {
				this.children = new ArrayList<FPTree.TreeNode>();
			}
			this.children.add(child);
			child.setParent(this);
		}

		public TreeNode getLinked() {
			return this.linked;
		}

		public void setLinked(TreeNode linked) {
			this.linked = linked;
		}

		public TreeNode parent;
		public List<TreeNode> children;
		public TreeNode linked;
	}

	/**
	 * 扫描数据库得到频繁1项集
	 * @return
	 */
	private Map<String, Integer> getF1Item() {
		f1Items = new HashMap<String, Integer>();
		for (String tran : transList) {
			String[] items = tran.split(SPLIT);
			for (String item : items) {
				if (f1Items.get(item + SPLIT) == null) {
					f1Items.put(item + SPLIT, 1);
				} else {
					f1Items.put(item + SPLIT, f1Items.get(item + SPLIT) + 1);
				}
			}
		}
		Map<String, Integer> res = new HashMap<String, Integer>();
		for (Map.Entry<String, Integer> entry : f1Items.entrySet()) {
			if (entry.getValue() >= F) {
				res.put(entry.getKey(), entry.getValue());
			}
		}
		return res;
	}

	/**
	 * 对频繁1项集按照频次进行排序
	 * @param oriMap
	 * @return
	 */
	private List<String> sortMapByValue(Map<String, Integer> oriMap) {
		List<String> sortedNodes = new ArrayList<String>();
		if (oriMap != null && !oriMap.isEmpty()) {
			List<Map.Entry<String, Integer>> entryList = new ArrayList<Map.Entry<String, Integer>>(oriMap.entrySet());
			Collections.sort(entryList, new Comparator<Map.Entry<String, Integer>>() {
				public int compare(Entry<String, Integer> entry1, Entry<String, Integer> entry2) {
					int value1 = 0, value2 = 0;
					try {
						value1 = entry1.getValue();
						value2 = entry2.getValue();
					} catch (NumberFormatException e) {
						value1 = 0;
						value2 = 0;
					}
					return value2 - value1;
				}
			});
			Iterator<Map.Entry<String, Integer>> iter = entryList.iterator();
			Map.Entry<String, Integer> tmpEntry = null;
			while (iter.hasNext()) {
				tmpEntry = iter.next();
				sortedNodes.add(tmpEntry.getKey());
			}
		}
		return sortedNodes;
	}

	/**
	 * 根据排序后的频繁1项集,对事务中包含的项集进行排序
	 * @param tran
	 * @param sortedItems
	 * @return
	 */
	private List<String> getSortedTrans(String tran, List<String> sortedItems) {
		List<String> sortedTranItems = new ArrayList<String>();
		String[] items = tran.split(SPLIT);
		for (String item : sortedItems) {
			for (String tItem : items) {
				if (item.equals(tItem + SPLIT)) {
					sortedTranItems.add(tItem + SPLIT);
					break;
				}
			}
		}
		return sortedTranItems;
	}

	/**
	 * 构建FP-Tree
	 * @return
	 */
	public TreeNode generateFPTree() {
		map.clear();
		TreeNode linked;

		root = new TreeNode(null, null);
		root.setChildren(new ArrayList<FPTree.TreeNode>());
		getF1Item();
		sortedList = sortMapByValue(f1Items);
		for (String tran : transList) {
			List<String> sortedTran = getSortedTrans(tran, sortedList);
			TreeNode currentRoot = root;
			for (int i = 0; i < sortedTran.size(); i++) {
				boolean flag = false;
				for (TreeNode node : currentRoot.getChildren()) {
					if (node.key.equals(sortedTran.get(i))) {
						flag = true;
						currentRoot = node;
						break;
					}
				}
				if (flag) {
					currentRoot.value = currentRoot.value + 1;
				} else {
					TreeNode node = new TreeNode(sortedTran.get(i), 1);
					linked = map.get(sortedTran.get(i));
					if (linked != null) {
						node.linked = linked;
					}
					map.put(sortedTran.get(i), node);

					TreeNode child = null;
					currentRoot.addChild(node);
					for (int index = i + 1; index < sortedTran.size(); index++) {
						child = new TreeNode(sortedTran.get(index), 1);
						node.addChild(child);
						node = child;
						linked = map.get(sortedTran.get(index));
						if (linked != null) {
							node.linked = linked;
						}
						map.put(sortedTran.get(index), node);
					}
					break;
				}
			}

		}

		return root;
	}

	/**
	 * 存储由FP-Tree产生频繁项集时的中间结果
	 * @author jeff
	 *
	 */
	public class TMP {
		String item;
		Integer count;
		Set<String> items;

		public TMP(String item) {
			this.item = item;
			this.count = 0;
			this.items = new HashSet<String>();
		}
	}

	/**
	 * 产生频繁项集
	 * @return
	 */
	public Map<String, Integer> getFItems() {
		Map<String, Integer> fItems = new HashMap<String, Integer>();
		generateFPTree();
		fItems.putAll(f1Items);
		for (int i = sortedList.size() - 1; i >= 0; i--) {
			String item = sortedList.get(i);
			TreeNode node = map.get(item);
			TreeNode parent = null;
			List<TMP> cItems = new ArrayList<FPTree.TMP>();
			while (node != null) {
				parent = node;
				TMP tmp = new TMP(item);
				cItems.add(tmp);
				tmp.count = parent.value;
				while (parent != null && parent.key != null) {
					//
					if (!item.equals(parent.key)) {
						tmp.items.add(parent.key);
					}
					parent = parent.parent;
				}
				node = node.linked;
			}
			List<TMP> res = new ArrayList<FPTree.TMP>();
			for (int k = 0; k < cItems.size(); k++) {
				if (cItems.get(k).count >= F) {
					TMP tmp = new TMP(cItems.get(k).item);
					tmp.items.addAll(cItems.get(k).items);
					tmp.count = cItems.get(k).count;
					res.add(tmp);
				}
				for (int l = k + 1; l < cItems.size(); l++) {
					TMP tmp = new TMP(cItems.get(k).item);
					tmp.items.addAll(cItems.get(k).items);
					tmp.items.retainAll(cItems.get(l).items);
					tmp.count = cItems.get(k).count + cItems.get(l).count;
					if (tmp.items.size() > 0) {
						res.add(tmp);
					}
				}
			}

			for (TMP tmp : res) {
				// 需要得到集合的子集
				for (Set<String> set : SetUtils.getSubset(tmp.items)) {
					if (set.size() == 0) {
						continue;
					}
					StringBuilder builder = new StringBuilder();
					builder.append(tmp.item);
					for (String it : set) {
						builder.append(it);
					}
					String fitem = builder.toString();
					if (fItems.containsKey(fitem)) {
						fItems.put(fitem, fItems.get(fitem).intValue() >= tmp.count ? fItems.get(fitem) : tmp.count);
					} else {
						fItems.put(fitem, tmp.count);
					}

				}
			}
		}

		return fItems;
	}

	/**
	 * 有频繁项集生成关联规则
	 * @return
	 */
	public Map<String, Double> getRules() {
		Map<String, Double> rules = new HashMap<String, Double>();
		Map<String, Integer> fItems = getFItems();

		Set<String> result = new HashSet<String>();
		for (String key1 : fItems.keySet()) {
			for (String key2 : fItems.keySet()) {
				Set<String> union1 = new HashSet<String>();
				Set<String> union2 = new HashSet<String>();

				union1.addAll(Arrays.asList(key1.split(SPLIT)));
				union2.addAll(Arrays.asList(key2.split(SPLIT)));

				result.clear();
				result.addAll(union1);
				result.retainAll(union2);
				if (result.size() == 0) {
					result.clear();
					result.addAll(union1);
					result.addAll(union2);
					boolean flag = false;
					for (String key3 : fItems.keySet()) {
						flag = false;
						for (String item : result) {
							if (!key3.contains(item + SPLIT)) {
								flag = true;
								break;
							}
						}
						if (!flag) {
							double conf = fItems.get(key3) / (double) fItems.get(key1);
							if (conf > C) {
								rules.put(key1 + "->" + key2, conf);
							}
						}
					}
				}
			}
		}

		return rules;

	}

	public static void main(String[] args) {
		FPTree tree = new FPTree();
		// tree.generateFPTree();
		// TreeNode root = tree.getRoot();
		// Map<String, TreeNode> map = tree.getMap();
		// System.out.println(root);
		// System.out.println(map);
		Map<String, Integer> fItems = tree.getFItems();
		for (Entry<String, Integer> item : fItems.entrySet()) {
			System.out.println(item.getKey() + ":   " + item.getValue());
		}

		Map<String, Double> rules = tree.getRules();
		for (Entry<String, Double> item : rules.entrySet()) {
			System.out.println(item.getKey() + ":   " + item.getValue());
		}
	}

}
Jeff Lee /
Published under (CC) BY-NC-SA in categories data mining  tagged with 关联规则  FP-Growth算法