package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.sql.planner.plan.MarkDistinctNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.ComparisonExpressionType;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.IfExpression;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.QualifiedName;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

/* loaded from: input_file:com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.class */
/**
 * Rewrites an aggregation that combines DISTINCT aggregations sharing a single mask with ordinary
 * (non-distinct) aggregations, so that the {@link MarkDistinctNode} producing that mask is no longer needed.
 *
 * <p>When the matching {@code MarkDistinct} is found below the aggregation, it is replaced by:
 * <ol>
 *   <li>a {@link GroupIdNode} with two grouping sets: {@code group-by keys + non-distinct arguments} and
 *       {@code group-by keys + distinct argument};</li>
 *   <li>an {@link AggregationNode} grouped by {@code group-by keys + distinct argument + group id} that computes
 *       the non-distinct aggregations;</li>
 *   <li>a {@link ProjectNode} that keeps the distinct column only where the group id equals 1 and each
 *       non-distinct result only where the group id equals 0 (NULL otherwise).</li>
 * </ol>
 * The original aggregation is then replaced by one that applies the distinct functions (without DISTINCT) to the
 * projected distinct column and {@code arbitrary} to each projected non-distinct result.
 *
 * <p>Runs only when {@link SystemSessionProperties#isOptimizeDistinctAggregationEnabled} is true for the session.
 */
public class OptimizeMixedDistinctAggregations implements PlanOptimizer {
    private final Metadata metadata;

    /**
     * State for one candidate aggregation, shared between {@code Optimizer.visitAggregation} (which creates it)
     * and {@code Optimizer.visitMarkDistinct} (which fills in the new symbols after rewriting the plan below).
     */
    public static class AggregateInfo {
        private final List<Symbol> groupBySymbols;
        private final Symbol mask;
        private final Map<Symbol, FunctionCall> aggregations;
        private final Map<Symbol, Signature> functions;

        // Set during the MarkDistinct rewrite; null until then.
        private Map<Symbol, Symbol> newNonDistinctAggregateSymbols;
        private Symbol newDistinctAggregateSymbol;
        private boolean foundMarkDistinct;

        /**
         * @param groupBySymbols grouping keys of the original aggregation (copied)
         * @param mask the single mask symbol shared by the distinct aggregations
         * @param aggregations output symbol to aggregation call of the original aggregation (copied)
         * @param functions output symbol to resolved function signature (copied)
         */
        public AggregateInfo(List<Symbol> groupBySymbols, Symbol mask, Map<Symbol, FunctionCall> aggregations, Map<Symbol, Signature> functions) {
            this.groupBySymbols = ImmutableList.copyOf(groupBySymbols);
            this.mask = mask;
            this.aggregations = ImmutableMap.copyOf(aggregations);
            this.functions = ImmutableMap.copyOf(functions);
        }

        /**
         * Returns the unique argument symbols of all non-distinct aggregations, in encounter order.
         * The result is a fresh mutable {@link ArrayList}; {@code visitMarkDistinct} replaces elements in place.
         */
        public List<Symbol> getOriginalNonDistinctAggregateArgs() {
            return argumentSymbols(false);
        }

        /** Returns the unique argument symbols of all distinct aggregations, in encounter order (fresh mutable list). */
        public List<Symbol> getOriginalDistinctAggregateArgs() {
            return argumentSymbols(true);
        }

        // Collectors.toList() makes no mutability guarantee, so collect into an explicit ArrayList.
        private List<Symbol> argumentSymbols(boolean distinct) {
            return aggregations.values().stream()
                    .filter(call -> call.isDistinct() == distinct)
                    .flatMap(call -> call.getArguments().stream())
                    .distinct()
                    .map(Symbol::from)
                    .collect(Collectors.toCollection(ArrayList::new));
        }

        public Symbol getNewDistinctAggregateSymbol() {
            return newDistinctAggregateSymbol;
        }

        public void setNewDistinctAggregateSymbol(Symbol symbol) {
            this.newDistinctAggregateSymbol = symbol;
        }

        /** Maps each original non-distinct aggregation output symbol to the projected symbol holding its value. */
        public Map<Symbol, Symbol> getNewNonDistinctAggregateSymbols() {
            return newNonDistinctAggregateSymbols;
        }

        public void setNewNonDistinctAggregateSymbols(Map<Symbol, Symbol> map) {
            this.newNonDistinctAggregateSymbols = map;
        }

        public Symbol getMask() {
            return mask;
        }

        public List<Symbol> getGroupBySymbols() {
            return groupBySymbols;
        }

        public Map<Symbol, FunctionCall> getAggregations() {
            return aggregations;
        }

        public Map<Symbol, Signature> getFunctions() {
            return functions;
        }

        /** Records that the MarkDistinct node producing {@link #getMask()} was found and rewritten. */
        public void foundMarkDistinct() {
            this.foundMarkDistinct = true;
        }

        public boolean isFoundMarkDistinct() {
            return foundMarkDistinct;
        }
    }

    /** Plan rewriter; the context carries the candidate aggregation (if any) down to its MarkDistinct source. */
    private static class Optimizer extends SimplePlanRewriter<Optional<AggregateInfo>> {
        private final PlanNodeIdAllocator idAllocator;
        private final SymbolAllocator symbolAllocator;
        private final Metadata metadata;

        private Optimizer(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Metadata metadata) {
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        }

        /**
         * Rewrites {@code node} if it has exactly one distinct mask, at least one unmasked aggregation, no FILTER
         * clauses, exactly one distinct argument symbol and comparable argument types, and the MarkDistinct
         * producing the mask is found in its source. Otherwise the node is left as is and its children are rewritten.
         */
        @Override
        public PlanNode visitAggregation(AggregationNode node, SimplePlanRewriter.RewriteContext<Optional<AggregateInfo>> context) {
            ImmutableSet<Symbol> masks = ImmutableSet.copyOf(node.getMasks().values());
            if (masks.size() != 1 || node.getMasks().size() == node.getAggregations().size()) {
                return context.defaultRewrite(node, Optional.empty());
            }
            // FILTER clauses are not carried through this rewrite.
            if (node.getAggregations().values().stream().map(FunctionCall::getFilter).anyMatch(Optional::isPresent)) {
                return context.defaultRewrite(node, Optional.empty());
            }

            AggregateInfo aggregateInfo = new AggregateInfo(node.getGroupingKeys(), Iterables.getOnlyElement(masks), node.getAggregations(), node.getFunctions());
            // visitMarkDistinct needs exactly one distinct argument symbol; bail out here rather than fail there.
            if (aggregateInfo.getOriginalDistinctAggregateArgs().size() != 1 || !checkAllEquatableTypes(aggregateInfo)) {
                return context.defaultRewrite(node, Optional.empty());
            }

            PlanNode source = context.rewrite(node.getSource(), Optional.of(aggregateInfo));
            if (!aggregateInfo.isFoundMarkDistinct()) {
                // NOTE(review): the rewritten source above is discarded and the children are rewritten again.
                return context.defaultRewrite(node, Optional.empty());
            }

            ImmutableMap.Builder<Symbol, FunctionCall> aggregations = ImmutableMap.builder();
            ImmutableMap.Builder<Symbol, Signature> functions = ImmutableMap.builder();
            for (Map.Entry<Symbol, FunctionCall> entry : node.getAggregations().entrySet()) {
                Symbol output = entry.getKey();
                FunctionCall call = entry.getValue();
                if (call.isDistinct()) {
                    // Same function, now applied without DISTINCT to the projected distinct column.
                    aggregations.put(output, new FunctionCall(call.getName(), call.getWindow(), false, ImmutableList.of(aggregateInfo.getNewDistinctAggregateSymbol().toSymbolReference())));
                    functions.put(output, node.getFunctions().get(output));
                }
                else {
                    // The value was already aggregated below; pick it up with arbitrary().
                    Symbol argument = aggregateInfo.getNewNonDistinctAggregateSymbols().get(output);
                    QualifiedName arbitrary = QualifiedName.of("arbitrary");
                    aggregations.put(output, new FunctionCall(arbitrary, call.getWindow(), false, ImmutableList.of(argument.toSymbolReference())));
                    functions.put(output, getFunctionSignature(arbitrary, argument));
                }
            }
            return new AggregationNode(idAllocator.getNextId(), source, aggregations.build(), functions.build(), Collections.emptyMap(), node.getGroupingSets(), node.getStep(), Optional.empty(), node.getGroupIdSymbol());
        }

        /**
         * Replaces the MarkDistinct that produces the candidate's mask with GroupId -> Aggregation -> Project.
         * MarkDistinct nodes that do not match the candidate are left in place.
         */
        @Override
        public PlanNode visitMarkDistinct(MarkDistinctNode node, SimplePlanRewriter.RewriteContext<Optional<AggregateInfo>> context) {
            Optional<AggregateInfo> candidate = context.get();
            if (!candidate.isPresent() || !candidate.get().getMask().equals(node.getMarkerSymbol())) {
                return context.defaultRewrite(node, Optional.empty());
            }
            AggregateInfo aggregateInfo = candidate.get();
            aggregateInfo.foundMarkDistinct();

            // The MarkDistinct node itself is dropped; only its (rewritten) source is kept.
            PlanNode source = context.rewrite(node.getSource(), Optional.empty());

            List<Symbol> groupBySymbols = aggregateInfo.getGroupBySymbols();
            List<Symbol> nonDistinctAggregateSymbols = aggregateInfo.getOriginalNonDistinctAggregateArgs();
            Symbol distinctSymbol = Iterables.getOnlyElement(aggregateInfo.getOriginalDistinctAggregateArgs());

            // If the distinct column is also a non-distinct argument (e.g. sum(a), count(DISTINCT a)), give the
            // non-distinct use its own symbol so the two grouping sets can carry the column independently.
            Symbol duplicatedDistinctSymbol = distinctSymbol;
            int duplicateIndex = nonDistinctAggregateSymbols.indexOf(distinctSymbol);
            if (duplicateIndex >= 0) {
                duplicatedDistinctSymbol = symbolAllocator.newSymbol(distinctSymbol.getName(), symbolAllocator.getTypes().get(distinctSymbol));
                nonDistinctAggregateSymbols.set(duplicateIndex, duplicatedDistinctSymbol);
            }

            List<Symbol> allSymbols = new ArrayList<>();
            allSymbols.addAll(groupBySymbols);
            allSymbols.addAll(nonDistinctAggregateSymbols);
            allSymbols.add(distinctSymbol);

            // 1. GroupId producing the two grouping sets.
            Symbol groupSymbol = symbolAllocator.newSymbol("group", BigintType.BIGINT);
            GroupIdNode groupIdNode = createGroupIdNode(groupBySymbols, nonDistinctAggregateSymbols, distinctSymbol, duplicatedDistinctSymbol, groupSymbol, allSymbols, source);

            // 2. Non-distinct aggregation grouped by (group-by keys, distinct column, group id).
            //    De-duplicated because the distinct column may itself be a group-by key.
            List<Symbol> groupByKeys = new ArrayList<>(groupBySymbols);
            groupByKeys.add(distinctSymbol);
            groupByKeys.add(groupSymbol);
            ImmutableMap.Builder<Symbol, Symbol> aggregationOutputSymbols = ImmutableMap.builder();
            AggregationNode aggregationNode = createNonDistinctAggregation(aggregateInfo, distinctSymbol, duplicatedDistinctSymbol, dedupe(groupByKeys), groupIdNode, node, aggregationOutputSymbols);

            // 3. Project that nulls out values belonging to the other grouping set.
            return createProjectNode(aggregationNode, aggregateInfo, distinctSymbol, groupSymbol, aggregationOutputSymbols.build());
        }

        /** Returns true when every non-distinct argument type and the mask type are comparable. */
        private boolean checkAllEquatableTypes(AggregateInfo aggregateInfo) {
            for (Symbol symbol : aggregateInfo.getOriginalNonDistinctAggregateArgs()) {
                if (!symbolAllocator.getTypes().get(symbol).isComparable()) {
                    return false;
                }
            }
            return symbolAllocator.getTypes().get(aggregateInfo.getMask()).isComparable();
        }

        /**
         * Projects the aggregation output, replacing the distinct column and each non-distinct result with
         * group-id-guarded IF expressions, and records the new symbols in {@code aggregateInfo}.
         *
         * @param distinctSymbol the original distinct argument column
         * @param groupSymbol the GroupId group-id column
         * @param aggregationOutputSymbols new non-distinct aggregation symbol to original aggregation output symbol
         */
        private ProjectNode createProjectNode(AggregationNode source, AggregateInfo aggregateInfo, Symbol distinctSymbol, Symbol groupSymbol, Map<Symbol, Symbol> aggregationOutputSymbols) {
            Assignments.Builder assignments = Assignments.builder();
            ImmutableMap.Builder<Symbol, Symbol> newNonDistinctAggregateSymbols = ImmutableMap.builder();
            for (Symbol symbol : source.getOutputSymbols()) {
                if (distinctSymbol.equals(symbol)) {
                    Symbol newSymbol = symbolAllocator.newSymbol("expr", symbolAllocator.getTypes().get(symbol));
                    aggregateInfo.setNewDistinctAggregateSymbol(newSymbol);
                    assignments.put(newSymbol, createIfExpression(groupSymbol.toSymbolReference(), new Cast(new LongLiteral("1"), "bigint"), ComparisonExpressionType.EQUAL, symbol.toSymbolReference(), symbolAllocator.getTypes().get(symbol)));
                }
                else if (aggregationOutputSymbols.containsKey(symbol)) {
                    Symbol newSymbol = symbolAllocator.newSymbol("expr", symbolAllocator.getTypes().get(symbol));
                    newNonDistinctAggregateSymbols.put(aggregationOutputSymbols.get(symbol), newSymbol);
                    assignments.put(newSymbol, createIfExpression(groupSymbol.toSymbolReference(), new Cast(new LongLiteral("0"), "bigint"), ComparisonExpressionType.EQUAL, symbol.toSymbolReference(), symbolAllocator.getTypes().get(symbol)));
                }
                else {
                    assignments.put(symbol, symbol.toSymbolReference());
                }
            }
            // The mask symbol is still emitted, as NULL. Presumably this is for consumers that still reference
            // it — TODO confirm.
            assignments.put(aggregateInfo.getMask(), new NullLiteral());
            aggregateInfo.setNewNonDistinctAggregateSymbols(newNonDistinctAggregateSymbols.build());
            return new ProjectNode(idAllocator.getNextId(), source, assignments.build());
        }

        /**
         * Builds a GroupId node with grouping sets {@code groupBy + nonDistinct} and {@code groupBy + distinct}.
         * Each output symbol in {@code allSymbols} reads from the same-named input, except
         * {@code duplicatedDistinctSymbol}, which reads from {@code distinctSymbol}.
         */
        private GroupIdNode createGroupIdNode(List<Symbol> groupBySymbols, List<Symbol> nonDistinctAggregateSymbols, Symbol distinctSymbol, Symbol duplicatedDistinctSymbol, Symbol groupSymbol, List<Symbol> allSymbols, PlanNode source) {
            List<Symbol> group0 = new ArrayList<>(groupBySymbols);
            group0.addAll(nonDistinctAggregateSymbols);
            List<Symbol> group1 = new ArrayList<>(groupBySymbols);
            group1.add(distinctSymbol);
            List<List<Symbol>> groupingSets = ImmutableList.of(dedupe(group0), dedupe(group1));

            // A symbol may be both a group-by key and an aggregation argument; Collectors.toMap throws on
            // duplicate keys, so de-duplicate first.
            Map<Symbol, Symbol> identityMappings = allSymbols.stream()
                    .distinct()
                    .collect(Collectors.toMap(
                            symbol -> symbol,
                            symbol -> symbol.equals(duplicatedDistinctSymbol) ? distinctSymbol : symbol));
            return new GroupIdNode(idAllocator.getNextId(), source, groupingSets, identityMappings, ImmutableMap.of(), groupSymbol);
        }

        /**
         * Builds the SINGLE-step aggregation computing every non-distinct aggregation under a fresh output symbol.
         * Arguments referring to {@code distinctSymbol} are redirected to {@code duplicatedDistinctSymbol} when the
         * two differ. Each new symbol is recorded in {@code outputSymbols} mapped to its original output symbol.
         */
        private AggregationNode createNonDistinctAggregation(AggregateInfo aggregateInfo, Symbol distinctSymbol, Symbol duplicatedDistinctSymbol, List<Symbol> groupByKeys, GroupIdNode groupIdNode, MarkDistinctNode markDistinctNode, ImmutableMap.Builder<Symbol, Symbol> outputSymbols) {
            ImmutableMap.Builder<Symbol, FunctionCall> aggregations = ImmutableMap.builder();
            ImmutableMap.Builder<Symbol, Signature> functions = ImmutableMap.builder();
            for (Map.Entry<Symbol, FunctionCall> entry : aggregateInfo.getAggregations().entrySet()) {
                FunctionCall call = entry.getValue();
                if (call.isDistinct()) {
                    continue;
                }
                Symbol newSymbol = symbolAllocator.newSymbol((Expression) entry.getKey().toSymbolReference(), symbolAllocator.getTypes().get(entry.getKey()));
                outputSymbols.put(newSymbol, entry.getKey());
                if (!duplicatedDistinctSymbol.equals(distinctSymbol) && call.getArguments().contains(distinctSymbol.toSymbolReference())) {
                    ImmutableList.Builder<Expression> arguments = ImmutableList.builder();
                    for (Expression argument : call.getArguments()) {
                        arguments.add(distinctSymbol.toSymbolReference().equals(argument) ? duplicatedDistinctSymbol.toSymbolReference() : argument);
                    }
                    aggregations.put(newSymbol, new FunctionCall(call.getName(), call.getWindow(), false, arguments.build()));
                }
                else {
                    aggregations.put(newSymbol, call);
                }
                functions.put(newSymbol, aggregateInfo.getFunctions().get(entry.getKey()));
            }
            // NOTE(review): reuses the MarkDistinct hash symbol even though the grouping keys differ — confirm it is
            // absent (or still valid) when this optimizer runs.
            return new AggregationNode(idAllocator.getNextId(), groupIdNode, aggregations.build(), functions.build(), Collections.emptyMap(), ImmutableList.of(groupByKeys), AggregationNode.Step.SINGLE, markDistinctNode.getHashSymbol(), Optional.empty());
        }

        /** Resolves {@code name} applied to a single argument of {@code symbol}'s type. */
        private Signature getFunctionSignature(QualifiedName name, Symbol symbol) {
            return metadata.getFunctionRegistry().resolveFunction(name, ImmutableList.of(new TypeSignatureProvider(symbolAllocator.getTypes().get(symbol).getTypeSignature())));
        }

        /** Returns {@code symbols} without repeated entries, keeping first-occurrence order. */
        private static List<Symbol> dedupe(List<Symbol> symbols) {
            return ImmutableSet.copyOf(symbols).asList();
        }

        /** Builds {@code IF(left <type> right, result, CAST(NULL AS resultType))}. */
        private static IfExpression createIfExpression(Expression left, Expression right, ComparisonExpressionType type, Expression result, Type resultType) {
            return new IfExpression(new ComparisonExpression(type, left, right), result, new Cast(new NullLiteral(), resultType.getDisplayName()));
        }
    }

    public OptimizeMixedDistinctAggregations(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override
    public PlanNode optimize(PlanNode plan, Session session, Map<Symbol, Type> types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) {
        if (!SystemSessionProperties.isOptimizeDistinctAggregationEnabled(session)) {
            return plan;
        }
        return SimplePlanRewriter.rewriteWith(new Optimizer(idAllocator, symbolAllocator, metadata), plan, Optional.empty());
    }
}
