package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.metadata.FunctionKind;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.spi.predicate.Domain;
import com.facebook.presto.spi.predicate.Marker;
import com.facebook.presto.spi.predicate.Range;
import com.facebook.presto.spi.predicate.TupleDomain;
import com.facebook.presto.spi.predicate.ValueSet;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.DomainTranslator;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.plan.ChildReplacer;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.LimitNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.RowNumberNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.planner.plan.TopNRowNumberNode;
import com.facebook.presto.sql.planner.plan.WindowNode;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.Expression;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.stream.Collectors;

/* loaded from: input_file:com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.class */
public class WindowFilterPushDown implements PlanOptimizer {
    // Signature named "row_number" of kind WINDOW, built with a bigint type signature and an empty
    // type-signature list (presumably return type and argument types). A window node's single
    // function is compared against this constant to decide whether the node can be rewritten.
    private static final Signature ROW_NUMBER_SIGNATURE = new Signature("row_number", FunctionKind.WINDOW, TypeSignature.parseTypeSignature("bigint"), (List<TypeSignature>) ImmutableList.of());
    // Handed to every Rewriter created by optimize(), which passes it to DomainTranslator.fromPredicate.
    private final Metadata metadata;

    /* loaded from: input_file:com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown$Rewriter.class */
    private static class Rewriter extends SimplePlanRewriter<Void> {
        // Supplies ids for the RowNumberNode / TopNRowNumberNode instances this rewriter creates.
        private final PlanNodeIdAllocator idAllocator;
        // Passed to DomainTranslator.fromPredicate when analysing filter predicates.
        private final Metadata metadata;
        // Passed to DomainTranslator.fromPredicate when analysing filter predicates.
        private final Session session;
        // Immutable copy of the symbol-to-type map; passed to DomainTranslator.fromPredicate.
        private final Map<Symbol, Type> types;

        /**
         * Creates a rewriter for one optimization pass.
         *
         * @param planNodeIdAllocator allocator used for ids of plan nodes created during rewriting
         * @param metadata metadata forwarded to the domain translator
         * @param session session forwarded to the domain translator
         * @param map symbol types; an immutable copy is kept
         * @throws NullPointerException if any argument is null
         */
        private Rewriter(PlanNodeIdAllocator planNodeIdAllocator, Metadata metadata, Session session, Map<Symbol, Type> map) {
            // Validate everything up front, in the same order as the fields are declared.
            Objects.requireNonNull(planNodeIdAllocator, "idAllocator is null");
            Objects.requireNonNull(metadata, "metadata is null");
            Objects.requireNonNull(session, "session is null");
            Objects.requireNonNull(map, "types is null");

            this.idAllocator = planNodeIdAllocator;
            this.metadata = metadata;
            this.session = session;
            this.types = ImmutableMap.copyOf(map);
        }

        /**
         * Rewrites the source of a window node and, when the node computes only an unordered
         * {@code row_number()}, replaces it with a {@link RowNumberNode} over the rewritten source.
         *
         * @param windowNode window node being visited; must contain exactly one window function
         * @param rewriteContext context used to rewrite the node's source
         * @return a new {@link RowNumberNode}, or the window node with its source replaced
         * @throws IllegalStateException if the node does not contain exactly one window function
         */
        @Override // com.facebook.presto.sql.planner.plan.PlanVisitor
        public PlanNode visitWindow(WindowNode windowNode, SimplePlanRewriter.RewriteContext<Void> rewriteContext) {
            Preconditions.checkState(windowNode.getWindowFunctions().size() == 1, "WindowFilterPushdown requires that WindowNodes contain exactly one window function");
            PlanNode rewrittenSource = rewriteContext.rewrite(windowNode.getSource());

            if (!canReplaceWithRowNumber(windowNode)) {
                return ChildReplacer.replaceChildren(windowNode, ImmutableList.of(rewrittenSource));
            }

            // No row limit and no hash symbol are known at this point, hence the two empty optionals.
            return new RowNumberNode(
                    this.idAllocator.getNextId(),
                    rewrittenSource,
                    windowNode.getPartitionBy(),
                    Iterables.getOnlyElement(windowNode.getWindowFunctions().keySet()),
                    Optional.empty(),
                    Optional.empty());
        }

        /**
         * Pushes a limit into a row-number computation directly beneath it.
         *
         * <p>If the rewritten source is a {@link RowNumberNode}, the limit is merged into its
         * per-partition row count. If it is a window node computing only {@code row_number()},
         * it is converted to a {@link TopNRowNumberNode}. In both cases the limit node itself is
         * dropped when there is no partitioning; otherwise it is kept on top of the new source.
         *
         * @param limitNode limit node being visited
         * @param rewriteContext context used to rewrite the node's source
         * @return the rewritten plan fragment
         */
        @Override // com.facebook.presto.sql.planner.plan.PlanVisitor
        public PlanNode visitLimit(LimitNode limitNode, SimplePlanRewriter.RewriteContext<Void> rewriteContext) {
            long count = limitNode.getCount();
            // The per-partition limit is carried as an int, so larger limits are not pushed down.
            if (count > Integer.MAX_VALUE) {
                return rewriteContext.defaultRewrite(limitNode);
            }

            PlanNode source = rewriteContext.rewrite(limitNode.getSource());
            int limit = Math.toIntExact(count);

            if (source instanceof RowNumberNode) {
                RowNumberNode limitedRowNumber = mergeLimit((RowNumberNode) source, limit);
                if (limitedRowNumber.getPartitionBy().isEmpty()) {
                    return limitedRowNumber;
                }
                source = limitedRowNumber;
            } else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source)) {
                WindowNode window = (WindowNode) source;
                // An unordered row_number() window is expected to have become a RowNumberNode in visitWindow.
                Verify.verify(!window.getOrderBy().isEmpty());
                TopNRowNumberNode topN = convertToTopNRowNumber(window, limit);
                if (window.getPartitionBy().isEmpty()) {
                    return topN;
                }
                source = topN;
            }

            return ChildReplacer.replaceChildren(limitNode, ImmutableList.of(source));
        }

        /**
         * Pushes an upper bound on the row-number symbol from a filter predicate into the
         * row-number computation directly beneath the filter.
         *
         * <p>When {@code extractUpperBound} yields a bound for the row-number symbol, a
         * {@link RowNumberNode} source gets the bound merged into its limit and a
         * {@code row_number()} window source is converted to a {@link TopNRowNumberNode};
         * the filter is then re-attached (or simplified) by {@code rewriteFilterSource}.
         * Otherwise only the filter's source is replaced.
         *
         * @param filterNode filter node being visited
         * @param rewriteContext context used to rewrite the node's source
         * @return the rewritten plan fragment
         */
        @Override // com.facebook.presto.sql.planner.plan.PlanVisitor
        public PlanNode visitFilter(FilterNode filterNode, SimplePlanRewriter.RewriteContext<Void> rewriteContext) {
            PlanNode source = rewriteContext.rewrite(filterNode.getSource());
            TupleDomain<Symbol> predicateDomain = DomainTranslator.fromPredicate(this.metadata, this.session, filterNode.getPredicate(), this.types).getTupleDomain();

            if (source instanceof RowNumberNode) {
                RowNumberNode rowNumberNode = (RowNumberNode) source;
                Symbol rowNumberSymbol = rowNumberNode.getRowNumberSymbol();
                OptionalInt upperBound = extractUpperBound(predicateDomain, rowNumberSymbol);
                if (upperBound.isPresent()) {
                    int limit = upperBound.getAsInt();
                    return rewriteFilterSource(filterNode, mergeLimit(rowNumberNode, limit), rowNumberSymbol, limit);
                }
            } else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source)) {
                WindowNode window = (WindowNode) source;
                // The only window function is row_number(), so its output symbol is the only key.
                Symbol rowNumberSymbol = Iterables.getOnlyElement(window.getWindowFunctions().keySet());
                OptionalInt upperBound = extractUpperBound(predicateDomain, rowNumberSymbol);
                if (upperBound.isPresent()) {
                    int limit = upperBound.getAsInt();
                    return rewriteFilterSource(filterNode, convertToTopNRowNumber(window, limit), rowNumberSymbol, limit);
                }
            }

            return ChildReplacer.replaceChildren(filterNode, ImmutableList.of(source));
        }

        /**
         * Re-attaches a filter on top of a source that already enforces a row-number limit.
         *
         * <p>If the predicate's domain for {@code symbol} is exactly {@code <= i}, that constraint
         * is removed from the predicate. When nothing else remains, the filter is dropped and
         * {@code planNode} is returned directly. In every other case the original predicate is
         * kept unchanged.
         *
         * @param filterNode original filter node, supplying the id and the predicate
         * @param planNode new source that already applies the row limit
         * @param symbol row-number symbol
         * @param i row limit that was pushed into {@code planNode}
         * @return {@code planNode}, or a filter node over it
         */
        private PlanNode rewriteFilterSource(FilterNode filterNode, PlanNode planNode, Symbol symbol, int i) {
            DomainTranslator.ExtractionResult extraction = DomainTranslator.fromPredicate(this.metadata, this.session, filterNode.getPredicate(), this.types);
            TupleDomain<Symbol> tupleDomain = extraction.getTupleDomain();

            if (!isEqualRange(tupleDomain, symbol, i)) {
                return new FilterNode(filterNode.getId(), planNode, filterNode.getPredicate());
            }

            // The limit fully covers the row-number constraint: keep only the other symbols' domains.
            Map<Symbol, Domain> otherDomains = tupleDomain.getDomains().get().entrySet().stream()
                    .filter(entry -> !entry.getKey().equals(symbol))
                    .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));

            Expression remainingPredicate = ExpressionUtils.combineConjuncts(
                    extraction.getRemainingExpression(),
                    DomainTranslator.toPredicate(TupleDomain.withColumnDomains(otherDomains)));

            if (remainingPredicate.equals(BooleanLiteral.TRUE_LITERAL)) {
                return planNode;
            }
            return new FilterNode(filterNode.getId(), planNode, remainingPredicate);
        }

        /**
         * Tells whether the values allowed for {@code symbol} are exactly the single range
         * {@code <= j}.
         *
         * @param tupleDomain domain extracted from a filter predicate
         * @param symbol symbol to inspect; callers are expected to pass a symbol that has a
         *        domain entry, since a missing entry is not checked here
         * @param j inclusive upper bound to compare against
         * @return {@code false} if the tuple domain is none or the value set differs
         */
        private static boolean isEqualRange(TupleDomain<Symbol> tupleDomain, Symbol symbol, long j) {
            if (tupleDomain.isNone()) {
                return false;
            }
            Domain symbolDomain = tupleDomain.getDomains().get().get(symbol);
            ValueSet expectedValues = ValueSet.ofRanges(Range.lessThanOrEqual(symbolDomain.getType(), j));
            return symbolDomain.getValues().equals(expectedValues);
        }

        /**
         * Derives the largest row number that {@code tupleDomain} allows for {@code symbol},
         * for use as a per-partition row limit.
         *
         * <p>Only bounds in {@code [1, Integer.MAX_VALUE]} are returned. {@code row_number()}
         * starts at 1, so a bound below 1 (for example from {@code row_number < 1}) matches no
         * rows and is not a valid limit; it is reported as absent so the filter itself discards
         * the rows. Rejecting such bounds also prevents {@link Math#toIntExact(long)} from
         * throwing {@link ArithmeticException} for bounds below {@code Integer.MIN_VALUE}.
         *
         * @param tupleDomain domain extracted from a filter predicate
         * @param symbol row-number symbol to inspect
         * @return the inclusive upper bound, or empty when the domain is none, has no entry for
         *         the symbol, has no upper bound, or the bound lies outside {@code [1, Integer.MAX_VALUE]}
         */
        private static OptionalInt extractUpperBound(TupleDomain<Symbol> tupleDomain, Symbol symbol) {
            if (tupleDomain.isNone()) {
                return OptionalInt.empty();
            }

            Domain rowNumberDomain = tupleDomain.getDomains().get().get(symbol);
            if (rowNumberDomain == null) {
                return OptionalInt.empty();
            }

            ValueSet values = rowNumberDomain.getValues();
            if (values.isAll() || values.isNone() || values.getRanges().getRangeCount() <= 0) {
                return OptionalInt.empty();
            }

            Range span = values.getRanges().getSpan();
            if (span.getHigh().isUpperUnbounded()) {
                return OptionalInt.empty();
            }

            Verify.verify(rowNumberDomain.getType().equals(BigintType.BIGINT));
            long upperBound = (Long) span.getHigh().getValue();
            if (span.getHigh().getBound() == Marker.Bound.BELOW) {
                // Exclusive bound: the largest admitted value is one less. If this wraps around
                // from Long.MIN_VALUE, the result exceeds Integer.MAX_VALUE and is rejected below.
                upperBound--;
            }

            if (upperBound > 0 && upperBound <= Integer.MAX_VALUE) {
                return OptionalInt.of(Math.toIntExact(upperBound));
            }
            return OptionalInt.empty();
        }

        /**
         * Returns a copy of {@code rowNumberNode} whose per-partition row limit is the smaller of
         * its existing limit (if any) and {@code i}.
         *
         * @param rowNumberNode node to copy; its id, source, partitioning, row-number symbol and
         *        hash symbol are carried over
         * @param i new row limit to merge in
         * @return the row-number node with the merged limit
         */
        private static RowNumberNode mergeLimit(RowNumberNode rowNumberNode, int i) {
            int effectiveLimit = rowNumberNode.getMaxRowCountPerPartition()
                    .map(existing -> Math.min(existing.intValue(), i))
                    .orElse(i);
            return new RowNumberNode(
                    rowNumberNode.getId(),
                    rowNumberNode.getSource(),
                    rowNumberNode.getPartitionBy(),
                    rowNumberNode.getRowNumberSymbol(),
                    Optional.of(effectiveLimit),
                    rowNumberNode.getHashSymbol());
        }

        /**
         * Builds a {@link TopNRowNumberNode} with a freshly allocated id from a window node that
         * has a single window function, reusing the window's source and specification.
         *
         * @param windowNode window node to convert; its only window-function symbol becomes the
         *        row-number symbol
         * @param i maximum row count per partition
         * @return the new top-N row-number node
         */
        private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int i) {
            return new TopNRowNumberNode(
                    this.idAllocator.getNextId(),
                    windowNode.getSource(),
                    windowNode.getSpecification(),
                    Iterables.getOnlyElement(windowNode.getWindowFunctions().keySet()),
                    i,
                    false,
                    Optional.empty());
        }

        /**
         * Tells whether a window node can be replaced by a plain row-number node: it must compute
         * only {@code row_number()} and have no ordering.
         *
         * @param windowNode window node to inspect
         * @return {@code true} if the node qualifies
         */
        private static boolean canReplaceWithRowNumber(WindowNode windowNode) {
            if (!canOptimizeWindowFunction(windowNode)) {
                return false;
            }
            return windowNode.getOrderBy().isEmpty();
        }

        /**
         * Tells whether a window node computes exactly one window function and that function is
         * {@code row_number()}.
         *
         * @param windowNode window node to inspect
         * @return {@code true} if the node's single window function has the row-number signature
         */
        private static boolean canOptimizeWindowFunction(WindowNode windowNode) {
            if (windowNode.getWindowFunctions().size() == 1) {
                // Exactly one entry, so the only value is the function to check.
                return isRowNumberSignature(Iterables.getOnlyElement(windowNode.getWindowFunctions().values()).getSignature());
            }
            return false;
        }

        /**
         * Tells whether {@code signature} equals the {@code row_number} window function signature.
         *
         * @param signature signature to compare; must not be null
         * @return {@code true} if it equals the row-number signature constant
         */
        private static boolean isRowNumberSignature(Signature signature) {
            boolean matches = signature.equals(ROW_NUMBER_SIGNATURE);
            return matches;
        }
    }

    /**
     * Creates the optimizer.
     *
     * @param metadata metadata passed on to the plan rewriter
     * @throws NullPointerException if {@code metadata} is null
     */
    public WindowFilterPushDown(Metadata metadata) {
        Objects.requireNonNull(metadata, "metadata is null");
        this.metadata = metadata;
    }

    /**
     * Rewrites the plan so that limits and row-number upper bounds are pushed into
     * {@code row_number()} window computations.
     *
     * @param planNode root of the plan to rewrite
     * @param session current session
     * @param map symbol types
     * @param symbolAllocator symbol allocator; only null-checked here
     * @param planNodeIdAllocator allocator for ids of newly created plan nodes
     * @return the rewritten plan
     * @throws NullPointerException if any argument is null
     */
    @Override // com.facebook.presto.sql.planner.optimizations.PlanOptimizer
    public PlanNode optimize(PlanNode planNode, Session session, Map<Symbol, Type> map, SymbolAllocator symbolAllocator, PlanNodeIdAllocator planNodeIdAllocator) {
        Objects.requireNonNull(planNode, "plan is null");
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(map, "types is null");
        Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        Objects.requireNonNull(planNodeIdAllocator, "idAllocator is null");

        Rewriter rewriter = new Rewriter(planNodeIdAllocator, this.metadata, session, map);
        return SimplePlanRewriter.rewriteWith(rewriter, planNode, null);
    }
}
