import java.util.*;

/**
 * Daniel Lemire
 * A simple implementation of the weighted slope one
 * algorithm in Java for item-based collaborative
 * filtering.
 * Assumes Java 1.5.
 *
 * See main function for example.
 *
 * June 1st 2006.
 * Revised by Marco Ponzi on March 29th 2007
 */
public class SlopeOne {

  /**
   * Small demo: builds a four-user data set, prints it together with the
   * deviation/frequency matrices, and runs two predictions.
   */
  public static void main(String args[]) {
    // this is my data base
    Map<UserId, Map<ItemId, Float>> data = new HashMap<UserId, Map<ItemId, Float>>();
    // items
    ItemId item1 = new ItemId(" candy");
    ItemId item2 = new ItemId(" dog");
    ItemId item3 = new ItemId(" cat");
    ItemId item4 = new ItemId(" war");
    ItemId item5 = new ItemId("strange food");

    mAllItems = new ItemId[]{item1, item2, item3, item4, item5};

    //I'm going to fill it in
    HashMap<ItemId, Float> user1 = new HashMap<ItemId, Float>();
    HashMap<ItemId, Float> user2 = new HashMap<ItemId, Float>();
    HashMap<ItemId, Float> user3 = new HashMap<ItemId, Float>();
    HashMap<ItemId, Float> user4 = new HashMap<ItemId, Float>();
    user1.put(item1, 1.0f);
    user1.put(item2, 0.5f);
    user1.put(item4, 0.1f);
    data.put(new UserId("Bob"), user1);
    user2.put(item1, 1.0f);
    user2.put(item3, 0.5f);
    user2.put(item4, 0.2f);
    data.put(new UserId("Jane"), user2);
    user3.put(item1, 0.9f);
    user3.put(item2, 0.4f);
    user3.put(item3, 0.5f);
    user3.put(item4, 0.1f);
    data.put(new UserId("Jo"), user3);
    user4.put(item1, 0.1f);
    //user4.put(item2,0.4f);
    //user4.put(item3,0.5f);
    user4.put(item4, 1.0f);
    user4.put(item5, 0.4f);
    data.put(new UserId("StrangeJo"), user4);

    // next, I create my predictor engine
    SlopeOne so = new SlopeOne(data);
    System.out.println("Here's the data I have accumulated...");
    so.printData();

    // then, I'm going to test it out...
    HashMap<ItemId, Float> user = new HashMap<ItemId, Float>();
    System.out.println("Ok, now we predict...");
    user.put(item5, 0.4f);
    System.out.println("Inputting...");
    SlopeOne.print(user);
    System.out.println("Getting...");
    SlopeOne.print(so.predict(user));
    // second round: the same user with one more known rating
    user.put(item4, 0.2f);
    System.out.println("Inputting...");
    SlopeOne.print(user);
    System.out.println("Getting...");
    SlopeOne.print(so.predict(user));
  }

  /** Raw ratings: user -> (item -> rating). */
  Map<UserId, Map<ItemId, Float>> mData;
  /** Average deviation: mDiffMatrix.get(a).get(b) is the mean of (rating of a - rating of b). */
  Map<ItemId, Map<ItemId, Float>> mDiffMatrix;
  /** Number of users who rated both items of a pair; same layout as mDiffMatrix. */
  Map<ItemId, Map<ItemId, Integer>> mFreqMatrix;
  /** Optional fixed item order used by printData(); only main() assigns it. */
  static ItemId[] mAllItems;

  /**
   * Stores the rating data and immediately builds the deviation and
   * frequency matrices from it.
   *
   * @param data user -> (item -> rating); must not be null
   * @throws NullPointerException if data is null
   */
  public SlopeOne(Map<UserId, Map<ItemId, Float>> data) {
    if (data == null) {
      throw new NullPointerException("data must not be null");
    }
    mData = data;
    buildDiffMatrix();
  }

  /**
   * Based on existing data, and using weights,
   * try to predict all missing ratings.
   * A possible trick to make this more scalable (not implemented here) is
   * to consider only mDiffMatrix entries having a large (>1) mFreqMatrix
   * entry.
   *
   * Items for which no prediction is possible (no rated item has ever been
   * rated together with them) are left out of the result.
   *
   * @param user the ratings already known for one user (item -> rating)
   * @return predicted ratings for the items that could be predicted, plus
   *         every entry of user copied over unchanged
   */
  public Map<ItemId, Float> predict(Map<ItemId, Float> user) {
    Map<ItemId, Float> result = new HashMap<ItemId, Float>();
    for (Map.Entry<ItemId, Map<ItemId, Float>> row : mDiffMatrix.entrySet()) {
      ItemId target = row.getKey();
      Map<ItemId, Float> diffs = row.getValue();
      Map<ItemId, Integer> counts = mFreqMatrix.get(target);
      if (diffs == null || counts == null) {
        continue;
      }
      float weightedSum = 0.0f;
      int totalWeight = 0;
      for (Map.Entry<ItemId, Float> rated : user.entrySet()) {
        Float diff = diffs.get(rated.getKey());
        Integer count = counts.get(rated.getKey());
        Float rating = rated.getValue();
        // A missing cell means the two items were never rated together:
        // that pair simply contributes nothing.
        if (diff == null || count == null || rating == null) {
          continue;
        }
        weightedSum += (diff.floatValue() + rating.floatValue()) * count.intValue();
        totalWeight += count.intValue();
      }
      if (totalWeight > 0) {
        result.put(target, weightedSum / totalWeight);
      }
    }
    // Known ratings always win over predictions.
    result.putAll(user);
    return result;
  }

  /**
   * Based on existing data, and not using weights,
   * try to predict all missing ratings.
   * A possible trick to make this more scalable (not implemented here) is
   * to consider only mDiffMatrix entries having a large (>1) mFreqMatrix
   * entry.
   *
   * Each prediction is the plain average of (deviation + rating) over the
   * rated items that have a deviation entry for the target item. Items with
   * no such entry at all are left out of the result.
   *
   * @param user the ratings already known for one user (item -> rating)
   * @return predicted ratings for the items that could be predicted, plus
   *         every entry of user copied over unchanged
   */
  public Map<ItemId, Float> weightlesspredict(Map<ItemId, Float> user) {
    Map<ItemId, Float> result = new HashMap<ItemId, Float>();
    for (Map.Entry<ItemId, Map<ItemId, Float>> row : mDiffMatrix.entrySet()) {
      Map<ItemId, Float> diffs = row.getValue();
      if (diffs == null) {
        continue;
      }
      float sum = 0.0f;
      int used = 0;
      for (Map.Entry<ItemId, Float> rated : user.entrySet()) {
        Float diff = diffs.get(rated.getKey());
        Float rating = rated.getValue();
        if (diff == null || rating == null) {
          continue;
        }
        sum += diff.floatValue() + rating.floatValue();
        used++;
      }
      // Divide by the number of deviations actually used, not by user.size(),
      // so skipped pairs do not drag the average towards zero.
      if (used > 0) {
        result.put(row.getKey(), sum / used);
      }
    }
    result.putAll(user);
    return result;
  }

  /**
   * Prints every user's ratings, then one line per item showing that item's
   * row of the deviation and frequency matrices.
   *
   * NOTE(review): part of this method and its helper was missing from the
   * source text this file was recovered from; the matrix dump below is a
   * reconstruction and its exact layout should be confirmed.
   */
  public void printData() {
    for (Map.Entry<UserId, Map<ItemId, Float>> entry : mData.entrySet()) {
      System.out.println(entry.getKey());
      print(entry.getValue());
    }
    ItemId[] items = itemsToPrint();
    for (int i = 0; i < items.length; i++) {
      System.out.print("\n" + items[i] + ":");
      printMatrixes(mDiffMatrix.get(items[i]), mFreqMatrix.get(items[i]));
    }
  }

  /**
   * Prints one matrix row: for every item, the average deviation followed by
   * the pair frequency. Missing cells are printed as "null".
   */
  private void printMatrixes(Map<ItemId, Float> ratings, Map<ItemId, Integer> frequencies) {
    ItemId[] items = itemsToPrint();
    for (int j = 0; j < items.length; j++) {
      Float diff = (ratings == null) ? null : ratings.get(items[j]);
      Integer freq = (frequencies == null) ? null : frequencies.get(items[j]);
      System.out.format("%10.3f", diff);
      System.out.print(" ");
      System.out.format("%10d", freq);
    }
    System.out.println();
  }

  /** Item order for printing: mAllItems when set, otherwise the matrix's own keys. */
  private ItemId[] itemsToPrint() {
    if (mAllItems != null) {
      return mAllItems;
    }
    return mDiffMatrix.keySet().toArray(new ItemId[0]);
  }

  /**
   * Prints one "item --> rating" line per entry of the given map.
   *
   * @param user item -> rating; a null map prints nothing
   */
  public static void print(Map<ItemId, Float> user) {
    if (user == null) {
      return;
    }
    for (Map.Entry<ItemId, Float> entry : user.entrySet()) {
      System.out.println(" " + entry.getKey() + " --> " + entry.getValue());
    }
  }

  /**
   * Rebuilds mDiffMatrix and mFreqMatrix from mData.
   *
   * First pass: for every user and every ordered pair of items that user
   * rated (including an item with itself), add the rating difference to the
   * pair's running sum and bump the pair's count. Second pass: turn each sum
   * into an average by dividing by its count.
   */
  public void buildDiffMatrix() {
    mDiffMatrix = new HashMap<ItemId, Map<ItemId, Float>>();
    mFreqMatrix = new HashMap<ItemId, Map<ItemId, Integer>>();
    // first iterate through users
    for (Map<ItemId, Float> user : mData.values()) {
      if (user == null) {
        continue;
      }
      // then iterate through user data
      for (Map.Entry<ItemId, Float> entry : user.entrySet()) {
        Map<ItemId, Float> diffRow = mDiffMatrix.get(entry.getKey());
        if (diffRow == null) {
          diffRow = new HashMap<ItemId, Float>();
          mDiffMatrix.put(entry.getKey(), diffRow);
        }
        Map<ItemId, Integer> freqRow = mFreqMatrix.get(entry.getKey());
        if (freqRow == null) {
          freqRow = new HashMap<ItemId, Integer>();
          mFreqMatrix.put(entry.getKey(), freqRow);
        }
        for (Map.Entry<ItemId, Float> entry2 : user.entrySet()) {
          Integer oldCount = freqRow.get(entry2.getKey());
          Float oldDiff = diffRow.get(entry2.getKey());
          float observedDiff = entry.getValue() - entry2.getValue();
          int count = (oldCount == null) ? 0 : oldCount.intValue();
          float diff = (oldDiff == null) ? 0.0f : oldDiff.floatValue();
          freqRow.put(entry2.getKey(), count + 1);
          diffRow.put(entry2.getKey(), diff + observedDiff);
        }
      }
    }
    // turn the accumulated sums into averages
    for (Map.Entry<ItemId, Map<ItemId, Float>> row : mDiffMatrix.entrySet()) {
      Map<ItemId, Integer> freqRow = mFreqMatrix.get(row.getKey());
      for (Map.Entry<ItemId, Float> cell : row.getValue().entrySet()) {
        int count = freqRow.get(cell.getKey()).intValue();
        cell.setValue(cell.getValue().floatValue() / count);
      }
    }
  }
}

/**
 * Identifies a user by name. Two ids with the same text are equal, so a
 * freshly created id can be used to look up an existing map entry.
 * Do not change content while the id is in use as a map key.
 */
class UserId {
  String content;

  public UserId(String s) {
    content = s;
  }

  @Override
  public boolean equals(Object other) {
    if (this == other) {
      return true;
    }
    if (!(other instanceof UserId)) {
      return false;
    }
    String theirs = ((UserId) other).content;
    return (content == null) ? theirs == null : content.equals(theirs);
  }

  @Override
  public int hashCode() {
    return (content == null) ? 0 : content.hashCode();
  }

  @Override
  public String toString() {
    return content;
  }
}

/**
 * Identifies an item by name. Two ids with the same text are equal, so a
 * freshly created id can be used to look up an existing map entry.
 * Do not change content while the id is in use as a map key.
 */
class ItemId {
  String content;

  public ItemId(String s) {
    content = s;
  }

  @Override
  public boolean equals(Object other) {
    if (this == other) {
      return true;
    }
    if (!(other instanceof ItemId)) {
      return false;
    }
    String theirs = ((ItemId) other).content;
    return (content == null) ? theirs == null : content.equals(theirs);
  }

  @Override
  public int hashCode() {
    return (content == null) ? 0 : content.hashCode();
  }

  @Override
  public String toString() {
    return content;
  }
}