package com.liquidnet.common.sharding.jdbc.algorithm;

import com.alibaba.fastjson.JSON;
import com.google.common.collect.Range;
import lombok.extern.slf4j.Slf4j;
import org.apache.shardingsphere.api.sharding.standard.PreciseShardingAlgorithm;
import org.apache.shardingsphere.api.sharding.standard.PreciseShardingValue;
import org.apache.shardingsphere.api.sharding.standard.RangeShardingAlgorithm;
import org.apache.shardingsphere.api.sharding.standard.RangeShardingValue;

import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.Locale;

/**
 * Modulo-based table sharding algorithm implementing both {@link PreciseShardingAlgorithm} and
 * {@link RangeShardingAlgorithm}.
 *
 * <p>A sharding value is parsed as a {@code long} and mapped to the table named
 * {@code logicTableName + floorMod(value, availableTableCount)}, lower-cased with {@link Locale#ROOT}.
 * Only names that are actually present in the available target tables are returned.
 *
 * @author zhanggb
 * Created by IntelliJ IDEA at 2020/9/23
 */
@Slf4j
public class ModuloShardingTableAlgorithm implements PreciseShardingAlgorithm<String>, RangeShardingAlgorithm<String> {

    /**
     * Routes a single sharding value (e.g. from {@code =} or {@code IN}) to exactly one table.
     *
     * @param collection           available target table names
     * @param preciseShardingValue sharding value; its value must be a decimal {@code long}
     * @return the matching table name, as it appears in {@code collection}
     * @throws NumberFormatException         if the sharding value is not a valid {@code long}
     * @throws UnsupportedOperationException if no tables are available or the computed table is not among them
     */
    @Override
    public String doSharding(Collection<String> collection, PreciseShardingValue<String> preciseShardingValue) {
        if (log.isDebugEnabled()) {
            // Guarded so the JSON serialization is skipped when debug logging is off.
            log.debug("### collection:{}", JSON.toJSONString(collection));
            log.debug("### preciseShardingValue:{}", JSON.toJSONString(preciseShardingValue));
        }
        String logicTableName = preciseShardingValue.getLogicTableName();
        if (collection.isEmpty()) {
            // Avoids a bare ArithmeticException ("/ by zero") from the modulo below.
            throw new UnsupportedOperationException("No available tables for logic table: " + logicTableName);
        }
        long shardingValue = Long.parseLong(preciseShardingValue.getValue());
        String moduloTableName = moduloTableName(logicTableName, shardingValue, collection.size());
        if (collection.contains(moduloTableName)) {
            log.info("### Modulo Table: {}", moduloTableName);
            return moduloTableName;
        }
        throw new UnsupportedOperationException(
                "Modulo table " + moduloTableName + " is not among available tables " + collection);
    }

    /**
     * Routes a range condition (e.g. {@code BETWEEN}, {@code <}, {@code >=}) to every table that may
     * hold a value inside the range.
     *
     * <p>Bounded endpoints are treated as inclusive even when the SQL bound is exclusive. This can add
     * at most one extra table, which is harmless because the SQL predicate still filters the rows.
     *
     * <p>A range missing either bound is routed to every modulo table. Consecutive values cycle through
     * all residues, so at most {@code collection.size()} values are examined; wide ranges are not walked
     * value by value.
     *
     * @param collection         available target table names
     * @param rangeShardingValue range sharding value; present bounds must be decimal {@code long}s
     * @return distinct matching table names in first-hit order; empty if the range or {@code collection} is empty
     * @throws NumberFormatException if a present bound is not a valid {@code long}
     */
    @Override
    public Collection<String> doSharding(Collection<String> collection, RangeShardingValue<String> rangeShardingValue) {
        if (log.isDebugEnabled()) {
            // Guarded so the JSON serialization is skipped when debug logging is off.
            log.debug("### collection:{}", JSON.toJSONString(collection));
            log.debug("### rangeShardingValue:{}", JSON.toJSONString(rangeShardingValue));
        }
        // LinkedHashSet: each table is routed once, in the order it was first hit.
        Collection<String> collect = new LinkedHashSet<>();
        int tableCount = collection.size();
        if (tableCount == 0) {
            log.info("### Modulo Tables: {}", collect);
            return collect;
        }
        String logicTableName = rangeShardingValue.getLogicTableName();
        Range<String> valueRange = rangeShardingValue.getValueRange();
        if (valueRange.hasLowerBound() && valueRange.hasUpperBound()) {
            long lowerEndpoint = Long.parseLong(valueRange.lowerEndpoint());
            long upperEndpoint = Long.parseLong(valueRange.upperEndpoint());
            // tableCount consecutive values already cover every residue, so stop there.
            for (int offset = 0; offset < tableCount; offset++) {
                long value = lowerEndpoint + offset;
                // value < lowerEndpoint means the addition overflowed past Long.MAX_VALUE.
                if (value > upperEndpoint || value < lowerEndpoint) {
                    break;
                }
                String moduloTableName = moduloTableName(logicTableName, value, tableCount);
                if (collection.contains(moduloTableName)) {
                    collect.add(moduloTableName);
                }
            }
        } else {
            // An unbounded side can reach every residue, so every modulo table is a candidate.
            // Calling lowerEndpoint()/upperEndpoint() here would throw IllegalStateException.
            for (int residue = 0; residue < tableCount; residue++) {
                String moduloTableName = moduloTableName(logicTableName, residue, tableCount);
                if (collection.contains(moduloTableName)) {
                    collect.add(moduloTableName);
                }
            }
        }
        log.info("### Modulo Tables: {}", collect);
        return collect;
    }

    /**
     * Builds the lower-cased table name for {@code value}.
     *
     * <p>{@link Math#floorMod(long, long)} keeps negative values on a non-negative suffix instead of
     * producing names such as {@code t_order-1}. {@link Locale#ROOT} keeps the lower-casing
     * independent of the JVM default locale.
     */
    private static String moduloTableName(String logicTableName, long value, int tableCount) {
        return (logicTableName + Math.floorMod(value, (long) tableCount)).toLowerCase(Locale.ROOT);
    }
}
