package com.bizvane.utils.tenant;

import com.aliyun.openservices.ons.api.*;
import com.bizvane.utils.jacksonutils.JacksonUtil;
import com.bizvane.utils.redisutils.SpringContextHolder;
import com.bizvane.utils.tokens.JWTUtil;
import com.bizvane.utils.tokens.SysAccountPO;
import com.bizvane.utils.tokens.TokenUtils;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.*;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.env.Environment;
import org.springframework.util.StopWatch;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.nio.charset.StandardCharsets;
import java.sql.Connection;
import java.util.*;


@Intercepts(@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class}))
public class QuarantineIntercepts implements Interceptor {

    public static final Logger logger = LoggerFactory.getLogger(QuarantineIntercepts.class);

    //mybatis plugin configuration xml
    public static final String ENABLE_INTERCEPTS = "enableIntercepts";
    public static final String IGNORE_TABLE = "ignoreTable";

    //enableIntercepts 的属性 禁用拦截
    public static final String DISABLE = "disable";
    //enableIntercepts 的属性 启用拦截
    public static final String ENABLE = "enable";
    //拼接的字段
    public static final String COMPANY_ID = "sys_company_id";

    //mq tag
    private static final String SQL_SELECT = "sql_select";
    //mq topic
    private static final String PUBLIC_BIZVANE_SQL_INTERCEPT_BEAN = "public_bizvane_sql_intercept_bean";

    /**
     * 获取请求体
     *
     * @return
     */
    public HttpServletRequest getRequest() {
        RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
        if (null == requestAttributes) {
            return null;
        }
        return ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
    }

    /**
     * 被拦截的sql中不能包含 "" 双引号，否则无法解析
     *
     * @param invocation
     * @return
     * @throws Throwable
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        QuarantineInterceptsBean bean = new QuarantineInterceptsBean();
        RoutingStatementHandler handler = (RoutingStatementHandler) invocation.getTarget();
        StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(handler, "delegate");
        MappedStatement mappedStatement = (MappedStatement) ReflectUtil.getFieldValue(delegate, "mappedStatement");
        Class<?> classType = Class.forName(mappedStatement.getId().substring(0, mappedStatement.getId().lastIndexOf(".")));
        String mName = mappedStatement.getId().substring(mappedStatement.getId().lastIndexOf(".") + 1);
        boolean flag = skipIntercepts(classType, mName);
        BoundSql boundSql = delegate.getBoundSql();
        String obj = boundSql.getSql().replaceAll("[\\s]+", " ");
        String sql = null;
        boolean userIdentity = true;
        try {
            if (SqlCommandType.SELECT != mappedStatement.getSqlCommandType() || flag) {
                return invocation.proceed();
            }

            if (validatorIntercepts()) return invocation.proceed();

            SysAccountPO sysAccountPO = TokenUtils.getStageUser(getRequest());
//            if (null == sysAccountPO) {
//                //本地测试方法
//                sysAccountPO = getStageUser(getRequest());
//            }
            if (null == sysAccountPO || null == sysAccountPO.getSysCompanyId()
                    || 0 == sysAccountPO.getSysCompanyId()) {
                userIdentity = false;
                return invocation.proceed();
            }

            StopWatch sw = new StopWatch("sql执行总耗时");
            try {
                sw.start("拼接耗时");
                sql = buildQuery(obj, sysAccountPO.getSysCompanyId().toString());
                ReflectUtil.setFieldValue(boundSql, "sql", sql);
                sw.stop();
                sw.start("运行耗时");
                return invocation.proceed();
            } finally {
                sw.stop();
                logger.info(sw.prettyPrint());
                logger.info("sql执行总耗时：[{}ms]", sw.getTotalTimeMillis());
            }
        } finally {
            //临时改动,只有商帆的服务器uat、prod的环境才会发送mq
            if (StringUtils.startsWithAny(InetAddress.getLocalHost().getHostName(), "app", "new")) {
                this.sendSQLBean(bean, mappedStatement, classType, mName, flag, obj, sql, userIdentity);
            }
        }

    }


    //class and method exists annotation skip intercepts
    protected boolean skipIntercepts(Class<?> classType, String mName) {

        if (StringUtils.endsWith(mName, "_COUNT")) {
            String subMethodName = StringUtils.substringBefore(mName, "_COUNT");
            for (Method method : classType.getDeclaredMethods()) {
                if (subMethodName.equals(method.getName()) && method.isAnnotationPresent(QuarantineAnnotation.class)) {
                    return true;
                }
            }
        }
        if (classType.isAnnotationPresent(QuarantineAnnotation.class)) {
            return true;
        }
        for (Method method : classType.getDeclaredMethods()) {
            if (mName.equals(method.getName()) && method.isAnnotationPresent(QuarantineAnnotation.class)) {
                return true;
            }
        }
        return false;
    }


    public static SysAccountPO getStageUser(HttpServletRequest request) {
        String sysAccountPoString = request.getHeader("stageToken");
        return JWTUtil.unsign(sysAccountPoString, SysAccountPO.class);
    }


    public static String addCompanyIdCondition(String coreQuery, String companyId) {
        Select select = createSelect(coreQuery);
        PlainSelect plainSelect = getPlainSelect(select);

        if (plainSelect.getFromItem() instanceof Table) {
            Table table = (Table) plainSelect.getFromItem();
            appendCondition(companyId, table, select, plainSelect);
        }

        if (plainSelect.getFromItem() instanceof SubSelect) {
            SubSelect subSelect = (SubSelect) plainSelect.getFromItem();
            SelectBody selectBody = subSelect.getSelectBody();
            Select select1 = createSelect(selectBody.toString());
            PlainSelect plainSelect1 = getPlainSelect(select1);
            Table table = (Table) plainSelect1.getFromItem();
            appendCondition(companyId, table, select1, plainSelect1);
            ((SubSelect) plainSelect.getFromItem()).setSelectBody(plainSelect1);
        }
        return select.toString();
    }

    protected static void appendCondition(String companyId, Table table, Select select, PlainSelect plainSelect) {
        Map<String, String> map = new LinkedHashMap<>();

        if (null != table.getAlias() && StringUtils.isNotBlank(table.getAlias().getName())) {
            map.put(table.getName(), table.getAlias().getName());
        } else {
            map.put(table.getName(), "${}");
        }
        if (null != plainSelect.getJoins()) {
            for (Join join : plainSelect.getJoins()) {
                Table rightItem = (Table) join.getRightItem();
                if (rightItem.getAlias() == null || StringUtils.isBlank(rightItem.getAlias().getName())) {
                    map.put(rightItem.getName(), "${}");
                } else {
                    map.put(rightItem.getName(), rightItem.getAlias().getName());
                }
            }
        }

        //当前sql是否存在被忽略的table，存在则不拼接
        List<String> ignoreTableList = getAllIgnoreTable();
        boolean flag = false;
        if (CollectionUtils.isNotEmpty(ignoreTableList)) {
            for (String str : ignoreTableList) {
                if (map.containsKey(str)) {
                    flag = true;
                    break;
                }
            }
        }

        if (!flag) {
            String alias = "";
            EqualsTo equalsTo = new EqualsTo();
            for (Map.Entry<String, String> entry : map.entrySet()) {
                if (StringUtils.isNotBlank(entry.getValue())) {
                    alias = entry.getValue();
                    break;
                }
            }
            StringBuffer sb = new StringBuffer();
            if (StringUtils.isNotBlank(alias) && !"${}".equals(alias)) {
                sb.append(alias).append(".").append(COMPANY_ID);
            } else {
                sb.append(COMPANY_ID);
            }
            equalsTo.setLeftExpression(new Column(sb.toString()));
            equalsTo.setRightExpression(new LongValue(companyId));
            addWhereCondition(select, equalsTo);
        }
    }


    public static String buildQuery(String sql, String condition) {
        return addCompanyIdCondition(sql, condition);
    }

    public static Select createSelect(String sql) {
        try {
            return (Select) CCJSqlParserUtil.parse(sql);
        } catch (JSQLParserException e) {
            throw new IllegalStateException("SQL parsing problem!", e);
        }
    }

    public static void addWhereCondition(Select select, Expression condition) {
        addWhereCondition(getPlainSelect(select), condition);
    }

    private static void addWhereCondition(PlainSelect plainSelect, Expression condition) {
        if (plainSelect.getWhere() == null) {
            plainSelect.setWhere(condition);
            return;
        }
        AndExpression andExpression = new AndExpression(plainSelect.getWhere(), condition);
        plainSelect.setWhere(andExpression);
    }

    private static PlainSelect getPlainSelect(Select select) {
        if (select.getSelectBody() instanceof PlainSelect) {
            return (PlainSelect) select.getSelectBody();
        }
        throw new UnsupportedOperationException("Not supported yet.");
    }


    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        } else {
            return target;
        }
    }


    public static List<String> getAllIgnoreTable() {
        if (isaBoolean()) {
            return null;
        }
        return Arrays.asList(StringUtils.split(QuarantineContextHolder.getQuarantine.getIgnoreTable(), ","));
    }


    private boolean validatorIntercepts() {
        if (null == QuarantineContextHolder.getQuarantine ||
                StringUtils.isBlank(QuarantineContextHolder.getQuarantine.getEnableIntercepts()) ||
                DISABLE.equals(QuarantineContextHolder.getQuarantine.getEnableIntercepts()) ||
                null == getRequest()) {
            return true;
        }
        return !ENABLE.equals(QuarantineContextHolder.getQuarantine.getEnableIntercepts());
    }


    private static boolean isaBoolean() {
        return null == QuarantineContextHolder.getQuarantine ||
                null == QuarantineContextHolder.getQuarantine.getIgnoreTable();
    }

    @Override
    public void setProperties(Properties properties) {
        QuarantineEntity quarantineEntity = new QuarantineEntity();
        quarantineEntity.setEnableIntercepts((String) properties.get(ENABLE_INTERCEPTS));
        quarantineEntity.setIgnoreTable((String) properties.get(IGNORE_TABLE));
        QuarantineContextHolder.setQuarantineEntity(quarantineEntity);
    }


    public static class ReflectUtil {
        /**
         * 利用反射获取指定对象的指定属性
         *
         * @param obj       目标对象
         * @param fieldName 目标属性
         * @return 目标属性的值
         */
        public static Object getFieldValue(Object obj, String fieldName) {
            Object result = null;
            Field field = ReflectUtil.getField(obj, fieldName);
            if (field != null) {
                field.setAccessible(true);
                try {
                    result = field.get(obj);
                } catch (IllegalArgumentException e) {
                    e.printStackTrace();
                } catch (IllegalAccessException e) {
                    e.printStackTrace();
                }
            }
            return result;
        }


        /**
         * 利用反射获取指定对象里面的指定属性
         *
         * @param obj       目标对象
         * @param fieldName 目标属性
         * @return 目标字段
         */
        private static Field getField(Object obj, String fieldName) {
            Field field = null;
            Class<?> clazz = obj.getClass();
            for (; clazz != Object.class; ) {
                if ("mappedStatement".equals(fieldName)) {
                    clazz = clazz.getSuperclass();
                }
                try {
                    field = clazz.getDeclaredField(fieldName);
                    break;
                } catch (NoSuchFieldException e) {
                    e.printStackTrace();
                }
            }
            return field;
        }


        /**
         * 利用反射设置指定对象的指定属性为指定的值
         *
         * @param obj        目标对象
         * @param fieldName  目标属性
         * @param fieldValue 目标值
         */
        public static void setFieldValue(Object obj, String fieldName,
                                         String fieldValue) {
            Field field = ReflectUtil.getField(obj, fieldName);
            if (field != null) {
                try {
                    field.setAccessible(true);
                    field.set(obj, fieldValue);
                } catch (IllegalArgumentException e) {
                    e.printStackTrace();
                } catch (IllegalAccessException e) {
                    e.printStackTrace();
                }
            }
        }
    }


    //发送记录拦截的前后的sql
    protected void sendSQLBean(QuarantineInterceptsBean bean, MappedStatement mappedStatement, Class<?> classType, String mName, boolean flag, String obj, String sql, boolean userIdentity) {
        if (SqlCommandType.SELECT != mappedStatement.getSqlCommandType()) return;
        Producer producer = getProducer();
        if (producer == null) return;

        try {
            bean.setApplicationName(getApplicationName());
            bean.setClazz(classType.getName());
            bean.setMethod(mName);
            bean.setJoinBefore(obj);
            bean.setJoinAfter(sql);
            //1已排除,2未排除
            bean.setExcluded(flag ? 1 : 2);
            //1有登陆身份,2无登陆身份
            bean.setUserIdentity(userIdentity ? 1 : 2);
            bean.setCreateDate(new Date());
            bean.setUpdateDate(new Date());
            Message message = new Message();
            message.setTag(SQL_SELECT);
            message.setTopic(PUBLIC_BIZVANE_SQL_INTERCEPT_BEAN);
            message.setBody(JacksonUtil.bean2Json(bean).getBytes());
            producer.sendAsync(message, new SendCallback() {
                @Override
                public void onSuccess(SendResult sendResult) {
                    logger.debug("发送拦截sql成功！", sendResult.getMessageId());
                }

                @Override
                public void onException(OnExceptionContext onExceptionContext) {
                    logger.error("发送sql拦截数据异常！", onExceptionContext.getException());
                }
            });
        } catch (Exception e) {
            e.printStackTrace();
            logger.error("发送sql拦截数据异常！", e);
        }
    }

    //获取mq生产者对象
    private Producer getProducer() {
        Producer producer;
        try {
            Object object = SpringContextHolder.getBean("producer");
            if (!(object instanceof Producer)) {
                return null;
            }

            producer = (Producer)object;
        } catch (Exception var3) {
            logger.error("MQ的发送对象出现异常，请检查配置信息!", var3);
            return null;
        }

        if (producer == null) {
            logger.error("MQ的发送对象为空，请检查配置信息!");
            return null;
        } else {
            return producer;
        }
    }

    private String getApplicationName() {
        Environment environment = null;
        try {
            environment = (Environment) SpringContextHolder.getBean("environment");
        } catch (Exception e) {
            logger.error("未获取到系统变量对象!", e);
            e.printStackTrace();
        }

        if (environment == null) {
            logger.error("系统环境变量为空!");
            return null;
        }
        String applicationName = environment.getProperty("spring.application.name");
        if (StringUtils.isBlank(applicationName)) {
            logger.error("应用名为空!");
            return null;
        }
        return applicationName;
    }

}
