Java Stream: Reduction, Grouping, and Partitioning

The problem

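Suppose you need to do some aggregation, for example:

  • 1. Group a list of transactions by currency and get the total for each currency (Map<Currency, Integer>)
  • 2. Split the transactions into expensive and inexpensive ones (Map<Boolean, List<Transaction>>)
  • 3. Multi-level grouping: group by city, then by expensive versus inexpensive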

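The traditional approach uses external iteration, with lots of for + if combinations. Grouping by currency, for instance, looks like this: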

 

private static void groupImperatively() {
    Map<Currency, List<Transaction>> transactionsByCurrencies = new HashMap<>();
    for (Transaction transaction : transactions) {
        Currency currency = transaction.getCurrency();
        List<Transaction> transactionsForCurrency = transactionsByCurrencies.get(currency);
        if (transactionsForCurrency == null) {
            transactionsForCurrency = new ArrayList<>();
            transactionsByCurrencies.put(currency, transactionsForCurrency);
        }
        transactionsForCurrency.add(transaction);
    }

    System.out.println(transactionsByCurrencies);
}

 

With streams, the same result takes a single collect call with a collector from Collectors:

Map<Currency, List<Transaction>> transactionsByCurrencies = transactions.stream().collect(groupingBy(Transaction::getCurrency));
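As a sketch of how all three tasks from the start of this article map onto collectors, here is a self-contained version. The Transaction record, its fields (currency as a plain String instead of java.util.Currency, city, value), the sample data and the EXPENSIVE_THRESHOLD cut-off are all hypothetical, and the record syntax assumes Java 16 or later.

```java
import java.util.List;
import java.util.Map;

import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.partitioningBy;
import static java.util.stream.Collectors.summingInt;

// Hypothetical transaction type; the article never shows its fields
record Transaction(String currency, String city, int value) {}

class TransactionReports {
    // Hypothetical cut-off above which a transaction counts as "expensive"
    static final int EXPENSIVE_THRESHOLD = 1000;

    static final List<Transaction> SAMPLE = List.of(
            new Transaction("EUR", "Paris", 1500),
            new Transaction("EUR", "Rome", 300),
            new Transaction("USD", "Paris", 800));

    // Task 1: total value per currency
    static Map<String, Integer> sumByCurrency(List<Transaction> txs) {
        return txs.stream()
                .collect(groupingBy(Transaction::currency, summingInt(Transaction::value)));
    }

    // Task 2: true -> expensive, false -> not expensive
    static Map<Boolean, List<Transaction>> splitByPrice(List<Transaction> txs) {
        return txs.stream()
                .collect(partitioningBy(t -> t.value() > EXPENSIVE_THRESHOLD));
    }

    // Task 3: group by city, then partition each city's transactions by price
    static Map<String, Map<Boolean, List<Transaction>>> byCityThenPrice(List<Transaction> txs) {
        return txs.stream()
                .collect(groupingBy(Transaction::city,
                        partitioningBy(t -> t.value() > EXPENSIVE_THRESHOLD)));
    }

    public static void main(String[] args) {
        System.out.println(sumByCurrency(SAMPLE)); // EUR=1800 and USD=800 (HashMap order may vary)
        System.out.println(splitByPrice(SAMPLE));
        System.out.println(byCityThenPrice(SAMPLE));
    }
}
```

Each task is one collect call; passing a collector as the second (downstream) argument of another collector is what replaces the nested for/if blocks.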

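The Collectors class provides many factory methods, which fall into three main categories:

  • 1. Reduction and summarization
  • 2. Grouping
  • 3. Partitioning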

Sample data

 

import java.util.*;

public class Dish {

    private final String name;
    private final boolean vegetarian;
    private final int calories;
    private final Type type;

    public Dish(String name, boolean vegetarian, int calories, Type type) {
        this.name = name;
        this.vegetarian = vegetarian;
        this.calories = calories;
        this.type = type;
    }

    public String getName() {
        return name;
    }

    public boolean isVegetarian() {
        return vegetarian;
    }

    public int getCalories() {
        return calories;
    }

    public Type getType() {
        return type;
    }

    public enum Type { MEAT, FISH, OTHER }

    @Override
    public String toString() {
        return name;
    }

    public static final List<Dish> menu =
            Arrays.asList( new Dish("pork", false, 800, Dish.Type.MEAT),
                           new Dish("beef", false, 700, Dish.Type.MEAT),
                           new Dish("chicken", false, 400, Dish.Type.MEAT),
                           new Dish("french fries", true, 530, Dish.Type.OTHER),
                           new Dish("rice", true, 350, Dish.Type.OTHER),
                           new Dish("season fruit", true, 120, Dish.Type.OTHER),
                           new Dish("pizza", true, 550, Dish.Type.OTHER),
                           new Dish("prawns", false, 400, Dish.Type.FISH),
                           new Dish("salmon", false, 450, Dish.Type.FISH));
}

 

Reduction and summarization

Starting with the menu data above, count how many dishes there are:

menu.stream().collect(counting());
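counting() is most useful as a downstream collector (inside groupingBy, as in the grouping examples later); to count a whole stream, Stream.count() gives the same number. A minimal comparison, using the nine dish names as plain strings so the snippet does not depend on the Dish class:

```java
import java.util.List;
import java.util.stream.Collectors;

class CountingDemo {
    static final List<String> DISHES = List.of("pork", "beef", "chicken", "french fries",
            "rice", "season fruit", "pizza", "prawns", "salmon");

    // Collector form: the same collector can also be used downstream of groupingBy
    static long countWithCollector(List<String> items) {
        return items.stream().collect(Collectors.counting());
    }

    // Terminal-operation form: shorter when counting the whole stream
    static long countDirectly(List<String> items) {
        return items.stream().count();
    }

    public static void main(String[] args) {
        System.out.println(countWithCollector(DISHES)); // 9
        System.out.println(countDirectly(DISHES));      // 9
    }
}
```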

Find the dish with the most calories (BinaryOperator.minBy gives the one with the fewest):

private static Dish findMostCaloricDishUsingComparator() {
  Comparator<Dish> dishCaloriesComparator = Comparator.comparingInt(Dish::getCalories);
  BinaryOperator<Dish> moreCaloricOf = BinaryOperator.maxBy(dishCaloriesComparator);
  return menu.stream().collect(reducing(moreCaloricOf)).get();
}

reducing also accepts a lambda directly to pick the maximum (or minimum); this is the recommended form:

private static Dish findMostCaloricDish() {
  return menu.stream().collect(reducing((d1, d2) -> d1.getCalories() > d2.getCalories() ? d1 : d2)).get();
}
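Collectors also offers dedicated maxBy and minBy collectors that take a Comparator and return an Optional. The Optional is empty when the stream is empty, which is why the snippets above end with .get(): on an empty menu that call would throw NoSuchElementException. A sketch with a hypothetical two-field MiniDish record (Java 16+) standing in for Dish:

```java
import java.util.Comparator;
import java.util.List;
import java.util.Optional;

import static java.util.stream.Collectors.maxBy;
import static java.util.stream.Collectors.minBy;

// Hypothetical two-field stand-in for the article's Dish class
record MiniDish(String name, int calories) {}

class MinMaxDemo {
    static final List<MiniDish> MENU = List.of(
            new MiniDish("pork", 800),
            new MiniDish("rice", 350),
            new MiniDish("season fruit", 120));

    static final Comparator<MiniDish> BY_CALORIES = Comparator.comparingInt(MiniDish::calories);

    static Optional<MiniDish> mostCaloric(List<MiniDish> menu) {
        return menu.stream().collect(maxBy(BY_CALORIES));
    }

    static Optional<MiniDish> leastCaloric(List<MiniDish> menu) {
        return menu.stream().collect(minBy(BY_CALORIES));
    }

    public static void main(String[] args) {
        System.out.println(mostCaloric(MENU).map(MiniDish::name).orElse("none"));  // pork
        System.out.println(leastCaloric(MENU).map(MiniDish::name).orElse("none")); // season fruit
        System.out.println(mostCaloric(List.of()).isPresent());                    // false
    }
}
```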

Sum

private static int calculateTotalCalories() {
  return menu.stream().collect(summingInt(Dish::getCalories));
}

Average

private static Double calculateAverageCalories() {
  return menu.stream().collect(averagingInt(Dish::getCalories));
}

 

Max, min, average, and sum in a single pass

private static IntSummaryStatistics calculateMenuStatistics() {
    return menu.stream().collect(summarizingInt(Dish::getCalories));
}

Result:
Menu statistics: IntSummaryStatistics{count=9, sum=4300, min=120, average=477.777778, max=800}
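The returned IntSummaryStatistics exposes each figure through its own getter. The sketch below works on the nine calorie values from the menu as a plain List<Integer>, an assumption made to keep it self-contained:

```java
import java.util.IntSummaryStatistics;
import java.util.List;
import java.util.stream.Collectors;

class StatsDemo {
    // Calorie values of the nine dishes in the article's menu
    static final List<Integer> CALORIES = List.of(800, 700, 400, 530, 350, 120, 550, 400, 450);

    static IntSummaryStatistics stats(List<Integer> values) {
        return values.stream().collect(Collectors.summarizingInt(Integer::intValue));
    }

    public static void main(String[] args) {
        IntSummaryStatistics s = stats(CALORIES);
        // Each statistic has its own getter
        System.out.println(s.getCount());   // 9
        System.out.println(s.getSum());     // 4300
        System.out.println(s.getMin());     // 120
        System.out.println(s.getMax());     // 800
        System.out.println(s.getAverage()); // roughly 477.78
    }
}
```

On a primitive stream, values.stream().mapToInt(Integer::intValue).summaryStatistics() produces the same object without going through a collector.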

Joining strings

 

private static String getShortMenu() {
    return menu.stream().map(Dish::getName).collect(joining());
}

private static String getShortMenuCommaSeparated() {
    return menu.stream().map(Dish::getName).collect(joining(", "));
}
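joining also has a three-argument overload, joining(delimiter, prefix, suffix), that wraps the joined result. A small sketch:

```java
import java.util.List;
import java.util.stream.Collectors;

class JoiningDemo {
    // joining(delimiter, prefix, suffix) wraps the joined names in brackets
    static String bracketed(List<String> names) {
        return names.stream().collect(Collectors.joining(", ", "[", "]"));
    }

    public static void main(String[] args) {
        System.out.println(bracketed(List.of("pork", "beef", "chicken"))); // [pork, beef, chicken]
        System.out.println(bracketed(List.of()));                          // []
    }
}
```

For an empty stream the result is just the prefix followed by the suffix.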

 

Generalized reduction with reducing

All of the collectors above are specializations of reduction and can be rewritten with reducing. For example, the total calories:

 

    // Total, using a lambda
    private static int calculateTotalCalories() {
        return menu.stream().collect(reducing(0, Dish::getCalories, (Integer i, Integer j) -> i + j));
    }

    // Total, using a method reference
    private static int calculateTotalCaloriesWithMethodReference() {
        return menu.stream().collect(reducing(0, Dish::getCalories, Integer::sum));
    }

    // Total without Collectors, via Stream.reduce
    private static int calculateTotalCaloriesWithoutCollectors() {
        return menu.stream().map(Dish::getCalories).reduce(Integer::sum).get();
    }

    // Total via IntStream
    private static int calculateTotalCaloriesUsingSum() {
        return menu.stream().mapToInt(Dish::getCalories).sum();
    }

 

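Of these options, the IntStream version is best: it is the most readable, and because it avoids boxing values into Integer, it should also perform best.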
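To make the "everything is a reduction" point concrete, here is a sketch that expresses counting() and a maximum with reducing. It is for illustration only; the dedicated collectors, or IntStream as recommended above, are clearer in real code. The calorie values are the menu's, as plain integers:

```java
import java.util.List;
import java.util.Optional;

import static java.util.stream.Collectors.reducing;

class ReducingDemo {
    static final List<Integer> CALORIES = List.of(800, 700, 400, 530, 350, 120, 550, 400, 450);

    // Equivalent of counting(): map every element to 1L, start from 0L, add
    static long count(List<Integer> values) {
        return values.stream().collect(reducing(0L, v -> 1L, Long::sum));
    }

    // Equivalent of a maximum: no identity value, so the result is an Optional
    static Optional<Integer> max(List<Integer> values) {
        return values.stream().collect(reducing(Integer::max));
    }

    public static void main(String[] args) {
        System.out.println(count(CALORIES));     // 9
        System.out.println(max(CALORIES).get()); // 800
    }
}
```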

Grouping with groupingBy

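Grouping, also called classification, is done with groupingBy. Its argument is a Function, the classification function, often written as a method reference. The result is a Map whose keys are the values returned by the classifier (in the first example below, the dish type) and whose values, by default, are lists of the elements in each group.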

Definitions (from java.util.stream.Collectors in the JDK):

 

    public static <T, K> Collector<T, ?, Map<K, List<T>>>
    groupingBy(Function<? super T, ? extends K> classifier) {
        return groupingBy(classifier, toList());
    }

    public static <T, K, A, D>
    Collector<T, ?, Map<K, D>> groupingBy(Function<? super T, ? extends K> classifier,
                                          Collector<? super T, A, D> downstream) {
        return groupingBy(classifier, HashMap::new, downstream);
    }
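Both overloads above delegate to a third one, groupingBy(classifier, mapFactory, downstream), which lets you choose the Map implementation. As a sketch, passing TreeMap::new yields groups sorted by key; the word list is illustrative:

```java
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

import static java.util.stream.Collectors.counting;
import static java.util.stream.Collectors.groupingBy;

class SortedGroupingDemo {
    static final List<String> WORDS =
            List.of("pork", "beef", "chicken", "rice", "pizza", "prawns", "salmon");

    // classifier, map factory, downstream collector: word counts per length, keys sorted
    static Map<Integer, Long> countByLength(List<String> words) {
        return words.stream().collect(groupingBy(String::length, TreeMap::new, counting()));
    }

    public static void main(String[] args) {
        System.out.println(countByLength(WORDS)); // {4=3, 5=1, 6=2, 7=1}
    }
}
```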

Examples

 

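    // Calorie levels used by the custom classifiers below
    public enum CaloricLevel { DIET, NORMAL, FAT }

    // Single-level grouping by dish type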
    private static Map<Dish.Type, List<Dish>> groupDishesByType() {
        return menu.stream().collect(groupingBy(Dish::getType));
    }

    // Single-level grouping with a custom classifier
    private static Map<CaloricLevel, List<Dish>> groupDishesByCaloricLevel() {
        return menu.stream().collect(
                groupingBy(dish -> {
                    if (dish.getCalories() <= 400) return CaloricLevel.DIET;
                    else if (dish.getCalories() <= 700) return CaloricLevel.NORMAL;
                    else return CaloricLevel.FAT;
                } ));
    }

    // Two-level grouping: first by type, then by calorie level
    private static Map<Dish.Type, Map<CaloricLevel, List<Dish>>> groupDishesByTypeAndCaloricLevel() {
        return menu.stream().collect(
                groupingBy(Dish::getType,
                        groupingBy((Dish dish) -> {
                            if (dish.getCalories() <= 400) return CaloricLevel.DIET;
                            else if (dish.getCalories() <= 700) return CaloricLevel.NORMAL;
                            else return CaloricLevel.FAT;
                        } )
                )
        );
    }

    // Count the dishes in each group
    private static Map<Dish.Type, Long> countDishesInGroups() {
        return menu.stream().collect(groupingBy(Dish::getType, counting()));
    }
    // Highest-calorie dish in each group, wrapped in Optional
    private static Map<Dish.Type, Optional<Dish>> mostCaloricDishesByType() {
        return menu.stream().collect(
                groupingBy(Dish::getType,
                        reducing((Dish d1, Dish d2) -> d1.getCalories() > d2.getCalories() ? d1 : d2)));
    }
    // Same, but unwrap the Optional with collectingAndThen
    private static Map<Dish.Type, Dish> mostCaloricDishesByTypeWithoutOptionals() {
        return menu.stream().collect(
                groupingBy(Dish::getType,
                        collectingAndThen(
                                reducing((d1, d2) -> d1.getCalories() > d2.getCalories() ? d1 : d2),
                                Optional::get)));
    }
    // Total calories in each group
    private static Map<Dish.Type, Integer> sumCaloriesByType() {
        return menu.stream().collect(groupingBy(Dish::getType,
                summingInt(Dish::getCalories)));
    }
    // Map each dish in a group to its calorie level and collect the levels into a Set
    private static Map<Dish.Type, Set<CaloricLevel>> caloricLevelsByType() {
        return menu.stream().collect(
                groupingBy(Dish::getType, mapping(
                        dish -> { if (dish.getCalories() <= 400) return CaloricLevel.DIET;
                        else if (dish.getCalories() <= 700) return CaloricLevel.NORMAL;
                        else return CaloricLevel.FAT; },
                        toSet() )));
    }
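Downstream collectors can be combined further. The sketch below uses a hypothetical MenuItem record (Java 16+, with the type as a String rather than the Dish.Type enum) to show mapping plus joining, which lists the dish names per group, and averagingInt, which averages the calories per group:

```java
import java.util.List;
import java.util.Map;

import static java.util.stream.Collectors.averagingInt;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.mapping;

// Hypothetical simplified dish: type is a String instead of the Dish.Type enum
record MenuItem(String name, String type, int calories) {}

class GroupDownstreamDemo {
    static final List<MenuItem> MENU = List.of(
            new MenuItem("pork", "MEAT", 800),
            new MenuItem("beef", "MEAT", 700),
            new MenuItem("prawns", "FISH", 400),
            new MenuItem("salmon", "FISH", 450),
            new MenuItem("rice", "OTHER", 350));

    // For each type: the dish names, joined in encounter order
    static Map<String, String> namesByType(List<MenuItem> menu) {
        return menu.stream().collect(
                groupingBy(MenuItem::type, mapping(MenuItem::name, joining(", "))));
    }

    // For each type: the average calories, as a Double
    static Map<String, Double> averageCaloriesByType(List<MenuItem> menu) {
        return menu.stream().collect(
                groupingBy(MenuItem::type, averagingInt(MenuItem::calories)));
    }

    public static void main(String[] args) {
        // HashMap iteration order may vary
        System.out.println(namesByType(MENU));           // MEAT -> "pork, beef", FISH -> "prawns, salmon", OTHER -> "rice"
        System.out.println(averageCaloriesByType(MENU)); // MEAT -> 750.0, FISH -> 425.0, OTHER -> 350.0
    }
}
```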

 

 

Partitioning with partitioningBy

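Partitioning is grouping with a predicate as the classifier, so the resulting map has only two keys: true and false. Other collectors, including groupingBy, can be nested inside a partition. Examples: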

 

    private static Map<Boolean, List<Dish>> partitionByVegetarian() {
        return menu.stream().collect(partitioningBy(Dish::isVegetarian));
    }

    private static Map<Boolean, Map<Dish.Type, List<Dish>>> vegetarianDishesByType() {
        return menu.stream().collect(partitioningBy(Dish::isVegetarian, groupingBy(Dish::getType)));
    }

    // Highest-calorie dish per partition; Optional::get is safe only because neither partition of this menu is empty
    private static Map<Boolean, Dish> mostCaloricPartitionedByVegetarian() {
        return menu.stream().collect(
                partitioningBy(Dish::isVegetarian,
                        collectingAndThen(
                                maxBy(comparingInt(Dish::getCalories)),
                                Optional::get)));
    }

 

Results:

Dishes partitioned by vegetarian: {false=[pork, beef, chicken, prawns, salmon], true=[french fries, rice, season fruit, pizza]}
Vegetarian Dishes by type: {false={FISH=[prawns, salmon], MEAT=[pork, beef, chicken]}, true={OTHER=[french fries, rice, season fruit, pizza]}}
Most caloric dishes by vegetarian: {false=pork, true=pizza}
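One practical difference from groupingBy: a partitioningBy map always contains both the true and the false key, even when one side is empty, whereas groupingBy only creates keys for classifier values that actually occur. A minimal sketch with a list that has no even numbers:

```java
import java.util.List;
import java.util.Map;

import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.partitioningBy;

class PartitionKeysDemo {
    // Always has both keys
    static Map<Boolean, List<Integer>> partitionByEven(List<Integer> nums) {
        return nums.stream().collect(partitioningBy(n -> n % 2 == 0));
    }

    // Only has keys for values that occur
    static Map<Boolean, List<Integer>> groupByEven(List<Integer> nums) {
        return nums.stream().collect(groupingBy(n -> n % 2 == 0));
    }

    public static void main(String[] args) {
        List<Integer> odds = List.of(1, 3, 5);
        System.out.println(partitionByEven(odds)); // {false=[1, 3, 5], true=[]}
        System.out.println(groupByEven(odds));     // {false=[1, 3, 5]}
    }
}
```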

Partitioning example: prime numbers

A prime is an integer greater than 1 that is divisible by no positive integer other than 1 and itself.
The traditional way to test whether a number is prime:

 

public static boolean isPrimeNormal(int num) {
    if (num < 2) {
        return false; // 0, 1 and negative numbers are not prime
    }
    for (int i = 2; i < num; i++) {
        if (num % i == 0) {
            return false;
        }
    }
    return true;
}

// Optimized version: only test divisors up to the square root of the number
private static boolean isPrime(int src) {
        double sqrt = Math.sqrt(src);
        if (src < 2) {
            return false;
        }
        if (src == 2 || src == 3) {
            return true;
        }
        if (src % 2 == 0) { // even numbers greater than 2 are not prime
            return false;
        }
        for (int i = 3; i <= sqrt; i+=2) {
            if (src % i == 0) {
                return false;
            }
        }
        return true;
    }

 

Stream version:

public static Map<Boolean, List<Integer>> partitionPrimes(int n) {
        return IntStream.rangeClosed(2, n).boxed()
                .collect(partitioningBy(candidate -> isPrime(candidate)));
    }

    public static boolean isPrime(int candidate) {
        return IntStream.rangeClosed(2, candidate-1)
                .limit((long) Math.floor(Math.sqrt((double) candidate)) - 1)
                .noneMatch(i -> candidate % i == 0);
    }


Result of partitionPrimes(100):

{false=[4, 6, 8, 9, 10, 12, 14, 15, 16, 18, 20, 21, 22, 24, 25, 26, 27, 28, 30, 32, 33, 34, 35, 36, 38, 39, 40, 42, 44, 45, 46, 48, 49, 50, 51, 52, 54, 55, 56, 57, 58, 60, 62, 63, 64, 65, 66, 68, 69, 70, 72, 74, 75, 76, 77, 78, 80, 81, 82, 84, 85, 86, 87, 88, 90, 91, 92, 93, 94, 95, 96, 98, 99, 100], true=[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]}
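The limit call in the stream version effectively tests the divisors 2 through floor(sqrt(candidate)). An equivalent and arguably simpler sketch bounds the range directly:

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

class PrimePartition {
    // Trial division by 2..floor(sqrt(candidate)); intended for candidate >= 2
    static boolean isPrime(int candidate) {
        int root = (int) Math.sqrt(candidate);
        return IntStream.rangeClosed(2, root).noneMatch(i -> candidate % i == 0);
    }

    static Map<Boolean, List<Integer>> partitionPrimes(int n) {
        return IntStream.rangeClosed(2, n).boxed()
                .collect(Collectors.partitioningBy(PrimePartition::isPrime));
    }

    public static void main(String[] args) {
        Map<Boolean, List<Integer>> result = partitionPrimes(100);
        System.out.println(result.get(true).size()); // 25 primes up to 100
    }
}
```

For 2 and 3 the range 2..root is empty, so noneMatch returns true and both are correctly reported as prime.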

Collecting into a Map with toMap

The toMap collector collects stream elements into a Map instance. It requires two functions:

  • keyMapper
  • valueMapper

keyMapper extracts the map key from each stream element, and valueMapper extracts the value associated with that key. Signature:

Collector<T, ?, Map<K,U>> toMap(Function<? super T, ? extends K> keyMapper,
  Function<? super T, ? extends U> valueMapper)

The following example collects books into a Map, with the ISBN as the key and the title as the value:


// Assumes a Book class with a (name, releaseYear, isbn) constructor and matching getters
List<Book> bookList = new ArrayList<>();
bookList.add(new Book("The Fellowship of the Ring", 1954, "0395489318"));
bookList.add(new Book("The Two Towers", 1954, "0345339711"));
bookList.add(new Book("The Return of the King", 1955, "0618129111"));

public Map<String, String> listToMap(List<Book> books) {
    return books.stream().collect(Collectors.toMap(Book::getIsbn, Book::getName));
}


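Function.identity() is a predefined function that simply returns its argument.
What happens if the collection contains duplicate elements? Unlike toSet, toMap does not quietly discard duplicates. That makes sense: if two elements produce the same key, which value should the key be associated with?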

// assertThatThrownBy comes from the AssertJ test library
List<String> listWithDuplicates = Arrays.asList("a", "bb", "c", "d", "bb");
assertThatThrownBy(() -> {
    listWithDuplicates.stream().collect(toMap(Function.identity(), String::length));
}).isInstanceOf(IllegalStateException.class);

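Note that toMap does not even check whether the values are equal: as soon as a key repeats, it throws IllegalStateException. To handle duplicate keys, use the overload that takes an additional mergeFunction argument: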

Collector<T, ?, Map<K,U>> toMap(Function<? super T, ? extends K> keyMapper,
  Function<? super T, ? extends U> valueMapper,
  BinaryOperator<U> mergeFunction)
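The BinaryOperator is the merge function: when two elements produce the same key, it receives the existing value and the new value and returns the one to store.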

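Example: in the book list above, two books share the release year 1954, and this merge function keeps the one encountered first: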

public Map<Integer, Book> listToMapWithDupKey(List<Book> books) {
    return books.stream().collect(Collectors.toMap(Book::getReleaseYear, Function.identity(),
      (existing, replacement) -> existing));
}
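To see the merge function at work, the sketch below uses a hypothetical Book record (Java 16+, so the accessors are name(), releaseYear() and isbn() rather than the getters used above). It also shows the four-argument toMap overload, whose last argument supplies the Map implementation:

```java
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;

// Hypothetical stand-in for the article's Book class
record Book(String name, int releaseYear, String isbn) {}

class BookIndex {
    static final List<Book> BOOKS = List.of(
            new Book("The Fellowship of the Ring", 1954, "0395489318"),
            new Book("The Two Towers", 1954, "0345339711"),
            new Book("The Return of the King", 1955, "0618129111"));

    // Duplicate key 1954: keep the book seen first
    static Map<Integer, Book> keepFirst(List<Book> books) {
        return books.stream().collect(Collectors.toMap(Book::releaseYear, Function.identity(),
                (existing, replacement) -> existing));
    }

    // Duplicate key 1954: keep the book seen last, and store the entries in a TreeMap
    static Map<Integer, Book> keepLastSorted(List<Book> books) {
        return books.stream().collect(Collectors.toMap(Book::releaseYear, Function.identity(),
                (existing, replacement) -> replacement, TreeMap::new));
    }

    public static void main(String[] args) {
        System.out.println(keepFirst(BOOKS).get(1954).name());      // The Fellowship of the Ring
        System.out.println(keepLastSorted(BOOKS).get(1954).name()); // The Two Towers
    }
}
```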

 
