Apache Mahout пользовательская реализация DataModel - PullRequest
0 голосов
/ 03 декабря 2018

У меня есть пользовательская реализация DataModel с именем ListDataModel

package com.recommender.models;


import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

import org.apache.mahout.cf.taste.common.Refreshable;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FastIDSet;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveArrayIterator;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.model.GenericItemPreferenceArray;
import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.PreferenceArray;

import com.recommender.util.utils;

public class ListDataModel<T> implements DataModel {

    /**
     * 
     */
    private static final long serialVersionUID = 1L;
    private List<T> __lst = new ArrayList<T>();
    private String __userColumnName;
    private String __itemColumnName;
    private String __referenceColumnName;

    public ListDataModel(List<T> list, String userColumnName, String itemColumnName,String refColumnName) {
        __lst = list;
        __userColumnName = userColumnName;
        __itemColumnName = itemColumnName;
        __referenceColumnName = refColumnName;
    }

    @SuppressWarnings("unchecked")
    @Override
    public void refresh(Collection<Refreshable> arg0) {
        __lst =  arg0.stream().map(m->(T)m).collect(Collectors.toList());
    }

    @Override
    public LongPrimitiveIterator getItemIDs() throws TasteException {
        List<Long> list = __lst.stream().map((r) -> utils.GetLong(r, __itemColumnName)).distinct().collect(Collectors.toList());
        return new LongPrimitiveArrayIterator(utils.tolongArray(list));
    }

    @Override
    public FastIDSet getItemIDsFromUser(long arg0) throws TasteException {
        List<Long> list = __lst.stream().filter(fil->utils.GetLong(fil, __userColumnName)==arg0).map((r) -> utils.GetLong(r, __itemColumnName)).distinct().collect(Collectors.toList());
        return new FastIDSet(utils.tolongArray(list));
    }

    @Override
    public float getMaxPreference() {
    //  return utils.GetFloat(__lst.stream().max(new ListComprator<>(__referenceColumnName)).get(),__referenceColumnName);
        return Float.NaN;
    }

    @Override
    public float getMinPreference() {
//      return utils.GetFloat(__lst.stream().min(new ListComprator<>(__referenceColumnName)).get(),__referenceColumnName);
        return Float.NaN;
    }

    @Override
    public int getNumItems() throws TasteException {
        Set<Long> set = new HashSet<>(__lst.size());
        return __lst.stream().filter(p -> set.add(utils.GetLong(p, __itemColumnName))).distinct().collect(Collectors.toList()).size();
    }

    @Override
    public int getNumUsers() throws TasteException {
        return __lst.stream().map(p -> utils.GetLong(p, __userColumnName)).distinct().collect(Collectors.toList()).size();
    }

    @Override
    public int getNumUsersWithPreferenceFor(long arg0) throws TasteException {
        return (int) __lst.stream().filter(fil->utils.GetLong(fil, __itemColumnName)==arg0).count();
    }

    @Override
    public int getNumUsersWithPreferenceFor(long arg0, long arg1) throws TasteException {
        return (int) __lst.stream().filter(fil->utils.GetLong(fil, __itemColumnName)==arg0 && utils.GetLong(fil, __itemColumnName)==arg0 ).count();
    }

    @Override
    public Long getPreferenceTime(long arg0, long arg1) throws TasteException {
        return 0L;
    }

    @Override
    public Float getPreferenceValue(long arg0, long arg1) throws TasteException {
        List<T> list = __lst.stream().filter(fil->utils.GetLong(fil, __userColumnName)==arg0 && utils.GetLong(fil, __itemColumnName)==arg1 ).collect(Collectors.toList());
        if(list.size()==0) return Float.NaN;
        return utils.GetFloat(list.get(0),__referenceColumnName);
    }

    @Override
    public PreferenceArray getPreferencesForItem(long arg0) throws TasteException {
        return new GenericItemPreferenceArray(__lst.stream()
                                                                  .filter(fil->utils.GetLong(fil, __itemColumnName)==arg0)
                                                                  .map(m->new Preference<T>(m, __userColumnName, __itemColumnName, __referenceColumnName))
                                                                  .collect(Collectors.toList()));
    }

    @Override
    public PreferenceArray getPreferencesFromUser(long arg0) throws TasteException {
        return  new GenericUserPreferenceArray(__lst.stream()
                  .filter(fil->utils.GetLong(fil, __userColumnName)==arg0)
                  .map(m->new Preference<T>(m, __userColumnName, __itemColumnName, __referenceColumnName))
                  .collect(Collectors.toList()));
    }

    @Override
    public LongPrimitiveIterator getUserIDs() throws TasteException {
        List<Long> list = __lst.stream().map((r) -> utils.GetLong(r, __userColumnName)).distinct().collect(Collectors.toList());
        return new LongPrimitiveArrayIterator(utils.tolongArray(list));
    }

    @Override
    public boolean hasPreferenceValues() {
        return true;
    }

    @Override
    public void removePreference(long arg0, long arg1) throws TasteException {
        __lst =__lst.stream().filter(fil->utils.GetLong(fil, __userColumnName)!=arg0 && utils.GetLong(fil, __itemColumnName)!=arg0).collect(Collectors.toList());
    }

    @Override
    public void setPreference(long arg0, long arg1, float arg2) throws TasteException {
        List<T> list = __lst.stream().filter(fil->utils.GetLong(fil, __userColumnName)==arg0 && utils.GetLong(fil, __itemColumnName)==arg0).collect(Collectors.toList());
        if(list.size()==0) return;
        utils.SetFloatValue(__lst.get(__lst.indexOf(list.get(0))),__referenceColumnName,arg2);
    }

}

.

package com.recommender.util;

import java.lang.reflect.Field;
import java.util.List;

public class utils {
        public static long[] tolongArray(List<Long> list) {
        long[] items = new long[list.size()];
        for (int i = 0; i < list.size(); i++) {
            items[i]=list.get(i);
        }
        return items;
    }
    public static Long GetLong(Object obj,String name) {
        String str = "";
        try {
            return Long.parseLong(getFiled(obj, name).get(obj).toString());
        } catch (IllegalArgumentException | IllegalAccessException  | SecurityException e) {
            System.out.println(str);
            System.out.println("Some error in getting Long from Field(com.recommender.util.GetLong)");
            e.printStackTrace();
        }
        return 0L;
    }
    public static String GetString(Object obj,String name) {
        try {
            return getFiled(obj, name).get(obj).toString();
        } catch (IllegalArgumentException | IllegalAccessException  | SecurityException e) {
            System.out.println("Some error in getting String from Field(com.recommender.util.GetString)");
            e.printStackTrace();
        }
        return "";
    }
    public static Float GetFloat(Object obj,String name) {
        try {
            return Float.parseFloat(getFiled(obj, name).get(obj).toString());
        } catch (IllegalArgumentException | IllegalAccessException  | SecurityException e) {
            System.out.println("Some error in getting String from Field(com.recommender.util.GetString)");
            e.printStackTrace();
        }
        return 0F;
    }


    public static void SetFloatValue(Object obj,String name,Float value) {
        try {
            getFiled(obj, name).setFloat(obj, value);
        } catch (IllegalArgumentException | IllegalAccessException e) {
            System.out.println("Some error in getting String from Field(com.recommender.util.SetFloatValue)");
            e.printStackTrace();
        } 
    }
    public static void SetLongValue(Object obj,String name,Long value) {
        try {
            getFiled(obj, name).setLong(obj, value);
        } catch (IllegalArgumentException | IllegalAccessException e) {
            System.out.println("Some error in getting String from Field(com.recommender.util.SetLongValue)");
            e.printStackTrace();
        } 
    }
    private static Field getFiled(Object obj,String name) {
        try {
            return obj.getClass().getField(name);
        } catch (NoSuchFieldException | SecurityException e) {
            System.out.println("Some error in getting Filed(com.recommender.util.getFiled)");
            e.printStackTrace();
        }
        return null;
    }
}

все выглядит хорошо, но когда я хочу использовать эту DataModel, чтобы взятьneighborhoods, ничего не возвращается, и приложение никогда не останавливается

package com.recommender.models;

public class Ratings {
    public Long UserId;
    public Long ISBN;
    public Long Rate;
    public Ratings(Long ui,Long isbn,Long rate) {
        UserId = ui;
        ISBN = isbn;
        Rate = rate;
    }
    public String toString() {
        return UserId.toString()+","+ISBN.toString()+","+Rate.toString();
    }
}

.

  DataModel dm = new ListDataModel<Ratings>(ratings, "UserId", "ISBN", "Rate");

  UserSimilarity sim = Users.UncenteredCosineSimilarity(dm);

  UserNeighborhood neighborhood = new NearestNUserNeighborhood(10,.1, sim, dm); 
//when reaches here App never stops and nothing return
...