package com.bizvane.config;

import ch.qos.logback.classic.Level;
import ch.qos.logback.classic.spi.ILoggingEvent;
import ch.qos.logback.core.filter.Filter;
import ch.qos.logback.core.spi.FilterReply;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.regex.Matcher;
import java.util.regex.Pattern;


/**
 * Logback filter that masks phone numbers found in formatted log messages.
 *
 * <p>Two textual shapes are recognised, both requiring an 11-digit value:
 * <ul>
 *   <li>plain JSON: {@code "phone":"13812345678"}</li>
 *   <li>JSON whose quotes are backslash-escaped: {@code \"phone\":\"13812345678\"}</li>
 * </ul>
 * The middle four digits are replaced with {@code ****}. Optional whitespace around the
 * colon is matched but is not reproduced in the masked output.
 *
 * <p>When a message contains such a value, the masked text is logged again through the
 * SLF4J logger named after the event's logger, at the event's level, and the original
 * event is denied. Messages without a match are passed on with
 * {@link FilterReply#NEUTRAL}.
 *
 * <p>The {@code level} property is only consulted by {@link #start()}; {@link #decide}
 * does not look at it.
 */
public class LogbackSensitiveDataFilter extends Filter<ILoggingEvent> {

    /** Configured level; only used by {@link #start()} to decide whether to start. */
    private Level level;

    /** Matches {@code "phone" : "ddddddddddd"}, capturing the 3/4/4 digit groups. */
    private static final Pattern PHONE_PATTERN_REGEX_1 = Pattern.compile(
            "\"phone\"\\s*:\\s*\"(\\d{3})(\\d{4})(\\d{4})\""
    );

    /** Same as {@link #PHONE_PATTERN_REGEX_1} but every quote is preceded by a backslash. */
    private static final Pattern PHONE_PATTERN_REGEX_2 = Pattern.compile(
            "\\\\\"phone\\\\\"\\s*:\\s*\\\\\"(\\d{3})(\\d{4})(\\d{4})\\\\\""
    );

    /** Replacement for {@link #PHONE_PATTERN_REGEX_1}: keeps groups 1 and 3, masks group 2. */
    private static final String PHONE_REPLACEMENT_1 = "\"phone\":\"$1****$3\"";

    /** Replacement for {@link #PHONE_PATTERN_REGEX_2}: keeps groups 1 and 3, masks group 2. */
    private static final String PHONE_REPLACEMENT_2 = "\\\\\"phone\\\\\":\\\\\"$1****$3\\\\\"";

    /**
     * Masks phone numbers in the event's formatted message.
     *
     * @param event the logging event to inspect
     * @return {@link FilterReply#DENY} when a phone number was found (the masked message has
     *         been logged again in its place), otherwise {@link FilterReply#NEUTRAL}
     */
    @Override
    public FilterReply decide(ILoggingEvent event) {
        String message = event.getFormattedMessage();
        if (message == null) {
            // Nothing to scan; Pattern.matcher(null) would throw.
            return FilterReply.NEUTRAL;
        }

        // Apply both masks unconditionally. Calling Matcher.find() once to test and a second
        // time to decide whether to mask would resume after the first hit, so a message with a
        // single phone number would be re-logged unmasked.
        String sanitized = maskPhoneNumberRegex2(maskPhoneNumberRegex1(message));

        // Every replacement swaps four digits for "****", so any match changes the text.
        if (sanitized.equals(message)) {
            return FilterReply.NEUTRAL;
        }

        // The masked text no longer matches either pattern, so if this filter sees the
        // re-logged message it answers NEUTRAL.
        logSanitizedMessage(event, sanitized);
        return FilterReply.DENY;
    }

    /** Masks every plain-quoted phone value in {@code message}. */
    private String maskPhoneNumberRegex1(String message) {
        return PHONE_PATTERN_REGEX_1.matcher(message).replaceAll(PHONE_REPLACEMENT_1);
    }

    /** Masks every backslash-escaped phone value in {@code message}. */
    private String maskPhoneNumberRegex2(String message) {
        return PHONE_PATTERN_REGEX_2.matcher(message).replaceAll(PHONE_REPLACEMENT_2);
    }

    /**
     * Logs {@code sanitizedMessage} through the logger named after the event's logger, using
     * the event's level. Levels other than ERROR/WARN/INFO/DEBUG/TRACE fall back to INFO.
     *
     * <p>Only the message string is passed on; nothing else from the original event is
     * forwarded to the logger call.
     *
     * @param event            the original event, used for logger name and level
     * @param sanitizedMessage the masked message text to emit
     */
    private void logSanitizedMessage(ILoggingEvent event, String sanitizedMessage) {
        Logger logger = LoggerFactory.getLogger(event.getLoggerName());
        switch (event.getLevel().levelInt) {
            case Level.ERROR_INT:
                logger.error(sanitizedMessage);
                break;
            case Level.WARN_INT:
                logger.warn(sanitizedMessage);
                break;
            case Level.DEBUG_INT:
                logger.debug(sanitizedMessage);
                break;
            case Level.TRACE_INT:
                logger.trace(sanitizedMessage);
                break;
            case Level.INFO_INT:
            default:
                logger.info(sanitizedMessage);
                break;
        }
    }

    /**
     * Sets the level property from its textual name via {@link Level#toLevel(String)}.
     *
     * @param level level name, e.g. as supplied from configuration
     */
    public void setLevel(String level) {
        this.level = Level.toLevel(level);
    }

    /** Starts the filter only when a level has been configured. */
    @Override
    public void start() {
        if (this.level != null) {
            super.start();
        }
    }
}