package com.bizvane.members.facade.utils;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.ReflectUtil;
import cn.hutool.core.util.StrUtil;
import com.bizvane.centerstageservice.models.po.SysBrandPo;
import com.bizvane.centerstageservice.rpc.BrandServiceRpc;
import com.bizvane.utils.enumutils.SysResponseEnum;
import com.bizvane.utils.jacksonutils.JacksonUtil;
import com.bizvane.utils.redisutils.SpringContextHolder;
import com.bizvane.utils.responseinfo.ResponseData;
import com.github.pagehelper.Constant;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.beans.BeansException;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.*;
import java.util.stream.Collectors;

/**
 * MyBatis {@link Executor} interceptor that encrypts values marked with {@link EncryptCardNo}
 * before a statement runs and decrypts them again afterwards.
 *
 * <p>For every intercepted {@code update}/{@code query} call the plugin:
 * <ol>
 *   <li>resolves the mapper method behind the mapped statement id,</li>
 *   <li>derives a company id from the statement parameter (directly, through Example criteria,
 *       or through a brand lookup) and checks it against the configured company list,</li>
 *   <li>encrypts annotated parameter values in place, runs the statement, decrypts annotated
 *       result values and restores the parameter values.</li>
 * </ol>
 *
 * <p>Per-invocation state (company id and mapper method) is kept in a {@link ThreadLocal} and is
 * always cleared when {@link #intercept(Invocation)} finishes, including on failure.
 */
@Slf4j
@ConditionalOnProperty(name = "aes.mybatisPlugin", havingValue = "true")
@Component
@Intercepts({@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})})
public class CryptoCardNoPlugin implements Interceptor, ApplicationContextAware {

    /** Encryption settings; populated once by {@link #setApplicationContext(ApplicationContext)}. */
    private static EncryptPhoneProperties encryptPhoneProperties;

    // NOTE(review): never assigned or read in this class; the RPC bean is fetched from the
    // application context on demand instead. Kept to leave the class layout untouched.
    private BrandServiceRpc brandServiceRpc;

    /** Per-thread invocation state, keyed by {@link #SYS_COMPANY_ID} and {@link #METHOD}. */
    private static final ThreadLocal<Map<String, Object>> mapThreadLocal = new ThreadLocal<>();

    /** AES key, loaded from {@link EncryptPhoneProperties}. */
    protected static String KEY = null;
    /** AES initialisation vector, loaded from {@link EncryptPhoneProperties}. */
    protected static String IV = null;

    private static final String COUNT_SUFFIX = "_COUNT";
    private static final String EXAMPLE_SUFFIX = "Example";
    private static final String SYS_COMPANY_ID = "sysCompanyId";
    private static final String COMPANY_ID = "companyId";
    private static final String SYS_BRAND_ID = "sysBrandId";
    private static final String BRAND_ID = "brandId";
    private static final String METHOD = "method";
    private static final String SELECT_BY_PRIMARY_KEY = "selectByPrimaryKey";
    private static final String ORED_CRITERIA = "oredCriteria";
    private static final String CRITERIA = "criteria";
    private static final String CONDITION = "condition";
    private static final String VALUE = "value";
    private static final String SECOND_VALUE = "secondValue";
    /** Property names that are looked up (as underscore-case column prefixes) in Example criteria. */
    private static final List<String> columnNameList = Arrays.asList(SYS_COMPANY_ID, COMPANY_ID, SYS_BRAND_ID, BRAND_ID);


    /**
     * Encrypts annotated parameter values, runs the statement, then decrypts annotated result
     * values and restores the parameter values.
     *
     * <p>Failures while preparing/encrypting are logged and the statement still runs (best effort,
     * as before). Failures of the statement itself or of the decrypt step propagate to the caller.
     *
     * @param invocation the intercepted executor call
     * @return the (possibly decrypted) result of the executor call
     * @throws Throwable whatever the executor call or the decrypt step throws
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        String methodName = null;
        Method method = null;
        Object parameter = null;
        // True once this call decided to encrypt the parameter; only then may it be "restored".
        boolean paramsEncrypted = false;
        try {
            Object[] args = invocation.getArgs();
            MappedStatement mappedStatement = (MappedStatement) args[0];
            String methodPath = mappedStatement.getId();
            methodName = methodPath.substring(methodPath.lastIndexOf(".") + 1);
            method = getMethod(methodPath);
            if (method != null && args.length > 1 && needEncrypt((parameter = args[1]))) {

                // Remember the mapper method for diagnostics further down the call chain.
                setLocalMethod(method);
                // Encrypt the parameter in place.
                paramsEncrypted = true;
                paramByDataTypeCrypt(parameter, method, PhoneEncryptUtil.EncryptDecryptPhoneEnum.ENCRYPT);

                // The 6-argument query carries a pre-built BoundSql: refresh its parameter object
                // from the (now encrypted) parameter.
                if (StrUtil.equals(invocation.getMethod().getName(), "query") && args.length > 5) {
                    BoundSql boundSqlOld = (BoundSql) args[5];
                    BoundSql boundSqlNew = mappedStatement.getBoundSql(parameter);
                    ReflectUtil.setFieldValue(boundSqlOld, "parameterObject", boundSqlNew.getParameterObject());
                    args[5] = boundSqlOld;
                }
            }
        } catch (Exception e) {
            // Pass the exception as the throwable argument so the stack trace is logged and no
            // dangling "{}" is left in the message.
            log.error("PhoneCryptoPlugin encrypt exception", e);
        }

        try {
            Object resultObj = invocation.proceed();

            // NOTE(review): for "_COUNT" statements the parameter is left encrypted, exactly as
            // before - confirm whether that is intended for the follow-up page query.
            if (method != null && !isEndWithCount(methodName) && needEncrypt(resultObj)) {
                // Decrypt the result.
                resultObj = resultByDataTypeDecrypt(resultObj);

                // Restore the parameter, but only if this call actually encrypted it; decrypting
                // plaintext would corrupt the caller's object or fail the call.
                if (paramsEncrypted) {
                    paramByDataTypeCrypt(parameter, method, PhoneEncryptUtil.EncryptDecryptPhoneEnum.DECRYPT);
                }
            }
            return resultObj;
        } finally {
            // Always clear the thread-local state, also when the statement throws; otherwise a
            // stale company id would be picked up by the next statement on this (pooled) thread.
            removeLocalBrandId();
        }
    }

    /**
     * Resolves the mapper method a mapped statement id refers to.
     *
     * @param methodPath mapped statement id, expected as {@code <mapper class name>.<method name>}
     * @return the method found by name on the mapper bean's class; for ids ending in
     *         {@code _COUNT} the method without that suffix is used as a fallback; {@code null}
     *         when the id has no dot, no bean is available or no method matches
     * @throws ClassNotFoundException if the class part of the id cannot be loaded
     */
    private Method getMethod(String methodPath) throws ClassNotFoundException {
        int separator = methodPath.lastIndexOf('.');
        if (separator < 0) {
            return null;
        }
        String methodName = methodPath.substring(separator + 1);
        String classPath = methodPath.substring(0, separator);

        Class<?> clazz = this.getClass().getClassLoader().loadClass(classPath);
        Object bean = SpringContextHolder.getBean(clazz);
        if (bean == null) {
            return null;
        }
        Class<?> realClass = bean.getClass();
        Method method = ReflectUtil.getMethodByName(realClass, methodName);
        if (method == null && isEndWithCount(methodName)) {
            String baseName = methodName.substring(0, methodName.length() - COUNT_SUFFIX.length());
            method = ReflectUtil.getMethodByName(realClass, baseName);
        }
        return method;
    }

    /**
     * Encrypts or decrypts, in place, every annotated value reachable from a statement parameter.
     *
     * @param paramObj                statement parameter (bean, list, map or Example); may be {@code null}
     * @param method                  mapper method of the statement; nothing happens when {@code null}
     * @param encryptDecryptPhoneEnum direction of the conversion
     */
    private void paramByDataTypeCrypt(Object paramObj, Method method, PhoneEncryptUtil.EncryptDecryptPhoneEnum encryptDecryptPhoneEnum) {
        // Identity-based so that one object reachable through several paths is converted once.
        Set<Object> visited = Collections.newSetFromMap(new IdentityHashMap<Object, Boolean>());
        paramByDataTypeCrypt(paramObj, method, encryptDecryptPhoneEnum, visited);
    }

    /**
     * Recursive worker for {@link #paramByDataTypeCrypt(Object, Method, PhoneEncryptUtil.EncryptDecryptPhoneEnum)}.
     *
     * @param visited objects already handled in this pass; MyBatis parameter maps commonly expose
     *                the same object under more than one key, which must not lead to a second
     *                conversion of the same instance
     */
    private void paramByDataTypeCrypt(Object paramObj, Method method,
                                      PhoneEncryptUtil.EncryptDecryptPhoneEnum encryptDecryptPhoneEnum,
                                      Set<Object> visited) {
        if (paramObj == null || method == null || !visited.add(paramObj)) {
            return;
        }
        if (paramObj instanceof Map) {
            for (Object value : ((Map<?, ?>) paramObj).values()) {
                if (value == null) {
                    continue;
                }
                // Only lists, annotated beans and Example objects are followed.
                if (value instanceof List
                        || value.getClass().getAnnotation(EncryptCardNo.class) != null
                        || isEndWithExample(value.getClass().getName())) {
                    paramByDataTypeCrypt(value, method, encryptDecryptPhoneEnum, visited);
                }
            }
        } else if (isEndWithExample(paramObj.getClass().getName())) {
            selectForExample(method, paramObj, encryptDecryptPhoneEnum);
        } else if (paramObj instanceof List) {
            for (Object element : (List<?>) paramObj) {
                if (element != null && element.getClass().getAnnotation(EncryptCardNo.class) != null) {
                    paramByDataTypeCrypt(element, method, encryptDecryptPhoneEnum, visited);
                }
            }
        } else {
            phoneFieldProcess(paramObj, encryptDecryptPhoneEnum);
        }
    }

    /**
     * Decrypts, in place, the annotated fields of a statement result.
     *
     * @param paramObj the result: a list whose annotated elements are decrypted, or a single
     *                 object that is decrypted when its class carries {@link EncryptCardNo}
     * @return the same instance that was passed in
     */
    private Object resultByDataTypeDecrypt(Object paramObj) {
        if (paramObj == null) {
            return null;
        }
        if (paramObj instanceof List) {
            // Elements are modified in place, so the list itself is never written to (it may be
            // unmodifiable).
            for (Object element : (List<?>) paramObj) {
                if (element != null && element.getClass().getAnnotation(EncryptCardNo.class) != null) {
                    resultByDataTypeDecrypt(element);
                }
            }
        } else if (paramObj.getClass().getAnnotation(EncryptCardNo.class) != null) {
            phoneFieldProcess(paramObj, PhoneEncryptUtil.EncryptDecryptPhoneEnum.DECRYPT);
        }
        return paramObj;
    }

    /**
     * Encrypts or decrypts a single string when the current company requires it.
     *
     * @param paramObj                the value to convert
     * @param encryptDecryptPhoneEnum direction of the conversion
     * @return the converted value, or the input unchanged when no conversion is required
     * @throws RuntimeException wrapping any failure of the lookup or the cipher call
     */
    private String stringParameterProcess(String paramObj, PhoneEncryptUtil.EncryptDecryptPhoneEnum encryptDecryptPhoneEnum) {
        try {
            boolean encrypting = isEncrypt(encryptDecryptPhoneEnum);
            // Diagnostic for an encrypt request without a known company id. The value itself is
            // deliberately not logged: it is the plaintext this plugin exists to protect.
            if (encrypting && getLocalCompanyId() == null) {
                log.error("encrypt phone brandId is null stringParameterProcess method: {}", getLocalMethod());
            }

            // Decide from the current company id whether conversion is required at all.
            if (!needEncrypt()) {
                return paramObj;
            }
            return encrypting
                    ? PhoneEncryptUtil.encryptAES(KEY, IV, paramObj)
                    : PhoneEncryptUtil.decryptAES(KEY, IV, paramObj);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Converts every field of {@code parameter} that is annotated with {@link EncryptCardNo}.
     *
     * <p>String values are converted directly. List values are replaced by a new list whose
     * string elements are converted and whose {@code null} elements are dropped; this only
     * happens when the field's declared type can hold a plain {@link List}. Values of any other
     * type are left untouched.
     *
     * @param parameter               the bean to process; must not be {@code null}
     * @param encryptDecryptPhoneEnum direction of the conversion
     */
    private void phoneFieldProcess(Object parameter, PhoneEncryptUtil.EncryptDecryptPhoneEnum encryptDecryptPhoneEnum) {
        List<Field> annotatedFields = FieldUtils.getFieldsListWithAnnotation(parameter.getClass(), EncryptCardNo.class);
        if (CollUtil.isEmpty(annotatedFields)) {
            return;
        }
        // Diagnostic for an encrypt request without a known company id; only type and method are
        // logged so that no plaintext ends up in the log.
        if (getLocalCompanyId() == null && isEncrypt(encryptDecryptPhoneEnum)) {
            log.error("encrypt phone brandId is null phoneFieldProcess parameterType: {}, method: {}",
                    parameter.getClass().getName(), getLocalMethod());
        }
        for (Field field : annotatedFields) {
            Object value = ReflectUtil.getFieldValue(parameter, field);
            if (value == null) {
                continue;
            }
            // Decide on the runtime value: a field declared as Object/CharSequence/Collection may
            // hold something that is neither a String nor a List.
            if (value instanceof String) {
                ReflectUtil.setFieldValue(parameter, field, stringParameterProcess((String) value, encryptDecryptPhoneEnum));
            } else if (value instanceof List && field.getType().isAssignableFrom(List.class)) {
                ReflectUtil.setFieldValue(parameter, field, cryptStrings((List<?>) value, encryptDecryptPhoneEnum, true));
            }
        }
    }

    /**
     * Converts the values of Example criteria that target an annotated column.
     *
     * <p>The entity class is taken from the return type of the mapper's
     * {@code selectByPrimaryKey} method; its annotated String/List fields, converted to
     * underscore case, are the column prefixes a criterion's condition is matched against.
     * Every group in {@code oredCriteria} is processed and each matching criterion is converted
     * exactly once.
     *
     * @param method                  mapper method of the statement
     * @param parameter               the Example object
     * @param encryptDecryptPhoneEnum direction of the conversion
     */
    private void selectForExample(Method method, Object parameter, PhoneEncryptUtil.EncryptDecryptPhoneEnum encryptDecryptPhoneEnum) {
        Method selectByPrimaryKeyMethod = ReflectUtil.getMethodByName(method.getDeclaringClass(), SELECT_BY_PRIMARY_KEY);
        if (selectByPrimaryKeyMethod == null) {
            // Without that method the entity class is unknown, so nothing can be matched.
            return;
        }
        List<Field> annotatedFields = FieldUtils.getFieldsListWithAnnotation(selectByPrimaryKeyMethod.getReturnType(), EncryptCardNo.class);
        if (CollUtil.isEmpty(annotatedFields)) {
            return;
        }
        List<String> columnPrefixes = annotatedFields.stream()
                .filter(field -> field.getType().equals(String.class) || field.getType().isAssignableFrom(List.class))
                .map(field -> StrUtil.toUnderlineCase(field.getName()))
                .collect(Collectors.toList());
        if (columnPrefixes.isEmpty()) {
            return;
        }

        List<?> oredCriteriaList = (List<?>) ReflectUtil.getFieldValue(parameter, ORED_CRITERIA);
        if (CollUtil.isEmpty(oredCriteriaList)) {
            return;
        }
        for (Object oredCriteria : oredCriteriaList) {
            if (oredCriteria == null) {
                continue;
            }
            List<?> criteriaList = (List<?>) ReflectUtil.getFieldValue(oredCriteria, CRITERIA);
            if (CollUtil.isEmpty(criteriaList)) {
                continue;
            }
            for (Object criterion : criteriaList) {
                String condition = (String) ReflectUtil.getFieldValue(criterion, CONDITION);
                boolean targetsAnnotatedColumn = columnPrefixes.stream()
                        .anyMatch(prefix -> StrUtil.startWith(condition, prefix));
                if (targetsAnnotatedColumn) {
                    cryptCriterionValue(criterion, VALUE, encryptDecryptPhoneEnum);
                    cryptCriterionValue(criterion, SECOND_VALUE, encryptDecryptPhoneEnum);
                }
            }
        }
    }

    /**
     * Converts one value slot of an Example criterion: a String is converted directly, a List is
     * replaced by a list with its String elements converted; {@code null} slots are skipped.
     */
    private void cryptCriterionValue(Object criterion, String fieldName, PhoneEncryptUtil.EncryptDecryptPhoneEnum encryptDecryptPhoneEnum) {
        Object current = ReflectUtil.getFieldValue(criterion, fieldName);
        if (current == null) {
            return;
        }
        Object converted = current;
        if (current instanceof List) {
            converted = cryptStrings((List<?>) current, encryptDecryptPhoneEnum, false);
        } else if (current instanceof String) {
            converted = stringParameterProcess((String) current, encryptDecryptPhoneEnum);
        }
        ReflectUtil.setFieldValue(criterion, fieldName, converted);
    }

    /**
     * Returns a new list in which every String element of {@code items} is converted and every
     * other element is carried over unchanged.
     *
     * @param dropNulls whether {@code null} elements are omitted from the returned list
     */
    private List<Object> cryptStrings(List<?> items, PhoneEncryptUtil.EncryptDecryptPhoneEnum encryptDecryptPhoneEnum, boolean dropNulls) {
        List<Object> converted = new ArrayList<>(items.size());
        for (Object item : items) {
            if (item == null) {
                if (!dropNulls) {
                    converted.add(null);
                }
                continue;
            }
            converted.add(item instanceof String ? stringParameterProcess((String) item, encryptDecryptPhoneEnum) : item);
        }
        return converted;
    }

    /** Tells whether the given mode is the encrypt direction (compared by code, as elsewhere). */
    private static boolean isEncrypt(PhoneEncryptUtil.EncryptDecryptPhoneEnum encryptDecryptPhoneEnum) {
        return encryptDecryptPhoneEnum.getCode() == PhoneEncryptUtil.EncryptDecryptPhoneEnum.ENCRYPT.getCode();
    }

    /**
     * Collects company/brand id values from the first criteria group of an Example.
     *
     * @param parameter an Example object, or a map holding its {@code oredCriteria} list
     * @param paramMap  receives {@code columnName -> value} for every criterion whose condition
     *                  starts with the underscore-case form of a name in {@link #columnNameList};
     *                  for list values the first non-null element is used
     */
    private static void parseMapForExample(Object parameter, Map<Object, Object> paramMap) {
        Object rawCriteria = parameter instanceof Map
                ? ((Map<?, ?>) parameter).get(ORED_CRITERIA)
                : ReflectUtil.getFieldValue(parameter, ORED_CRITERIA);
        List<?> oredCriteriaList = (List<?>) rawCriteria;
        if (CollUtil.isEmpty(oredCriteriaList)) {
            return;
        }
        Object oredCriteria = oredCriteriaList.get(0);
        if (oredCriteria == null) {
            return;
        }
        List<?> criteriaList = (List<?>) ReflectUtil.getFieldValue(oredCriteria, CRITERIA);
        if (CollUtil.isEmpty(criteriaList)) {
            return;
        }
        for (Object criterion : criteriaList) {
            String condition = (String) ReflectUtil.getFieldValue(criterion, CONDITION);
            Object value = ReflectUtil.getFieldValue(criterion, VALUE);
            if (value == null) {
                continue;
            }
            for (String columnName : columnNameList) {
                if (!StrUtil.startWith(condition, StrUtil.toUnderlineCase(columnName))) {
                    continue;
                }
                if (value instanceof List) {
                    // An empty or all-null IN list simply contributes nothing.
                    ((List<?>) value).stream()
                            .filter(Objects::nonNull)
                            .findFirst()
                            .ifPresent(first -> paramMap.put(columnName, first));
                } else if (value instanceof Number) {
                    paramMap.put(columnName, value);
                }
            }
        }
    }

    /**
     * Determines the company id for the current invocation and caches it in the thread-local
     * state.
     *
     * <p>A cached id wins. Otherwise the id is read from {@code paramObj}
     * ({@code sysCompanyId}, then {@code companyId}); failing that, a brand id
     * ({@code sysBrandId}, then {@code brandId}) is resolved to its company through
     * {@link BrandServiceRpc}, unless {@link #checkService()} forbids the lookup.
     *
     * @param paramObj statement parameter or result to inspect; may be {@code null}
     * @return the company id, or {@code null} when none could be determined
     */
    private static Long getSysCompanyIdFromObj(Object paramObj) {
        Long sysCompanyId = getLocalCompanyId();
        if (sysCompanyId == null && paramObj != null) {
            Map<Object, Object> paramMap = new HashMap<>(16);
            parseParameter(paramObj, paramMap);
            sysCompanyId = toLong(firstNonNull(paramMap, SYS_COMPANY_ID, COMPANY_ID));
            if (sysCompanyId == null) {
                Long sysBrandId = toLong(firstNonNull(paramMap, SYS_BRAND_ID, BRAND_ID));
                if (sysBrandId != null && checkService()) {
                    ResponseData<SysBrandPo> brandByID = SpringContextHolder.getApplicationContext().getBean(BrandServiceRpc.class).getBrandByID(sysBrandId);
                    if (brandByID != null && brandByID.getCode() == SysResponseEnum.SUCCESS.getCode() && brandByID.getData() != null) {
                        sysCompanyId = brandByID.getData().getSysCompanyId();
                    }
                }
            }
        }
        if (sysCompanyId != null) {
            setLocalCompanyId(sysCompanyId);
        }
        return sysCompanyId;
    }

    /**
     * Returns the value under {@code primaryKey}, or the value under {@code fallbackKey} when the
     * primary one is missing or mapped to {@code null} (a plain {@code getOrDefault} would return
     * the mapped {@code null} and never consult the fallback).
     */
    private static Object firstNonNull(Map<Object, Object> source, String primaryKey, String fallbackKey) {
        Object primary = source.get(primaryKey);
        return primary != null ? primary : source.get(fallbackKey);
    }

    /**
     * Parses the string form of {@code value} as a {@link Long}.
     *
     * @return the parsed number, or {@code null} when {@code value} is {@code null} or not an
     *         integral number (such a value is treated as "no id" instead of failing the call)
     */
    private static Long toLong(Object value) {
        if (value == null) {
            return null;
        }
        try {
            return Long.valueOf(value.toString());
        } catch (NumberFormatException e) {
            log.warn("ignore non-numeric id value of type {}", value.getClass().getName());
            return null;
        }
    }

    /**
     * Flattens a statement parameter or result into {@code map} so that id properties can be
     * looked up by name.
     *
     * <p>Scalars (numbers, booleans, characters, character sequences) contribute nothing. Maps
     * contribute their String/Number entries and are searched recursively; lists are represented
     * by their first element; Example objects contribute their id criteria; any other object is
     * converted with {@link JacksonUtil#bean2Map}.
     *
     * @param paramObj the object to inspect; may be {@code null}
     * @param map      receives the discovered entries
     */
    private static void parseParameter(Object paramObj, Map<Object, Object> map) {
        if (paramObj == null) {
            return;
        }
        // A bare scalar carries no named properties; strings in particular must not be sent
        // through the bean-to-map conversion.
        if (paramObj instanceof Number || paramObj instanceof Boolean
                || paramObj instanceof Character || paramObj instanceof CharSequence) {
            return;
        }
        if (paramObj instanceof Map) {
            Map<?, ?> paramMap = (Map<?, ?>) paramObj;
            if (paramMap.keySet().contains(ORED_CRITERIA)) {
                parseMapForExample(paramObj, map);
            }
            for (Map.Entry<?, ?> entry : paramMap.entrySet()) {
                Object entryValue = entry.getValue();
                if (entryValue == null) {
                    continue;
                }
                if (entryValue instanceof String || entryValue instanceof Number) {
                    map.put(entry.getKey(), entryValue);
                } else {
                    parseParameter(entryValue, map);
                }
            }
        } else if (paramObj instanceof List) {
            List<?> list = (List<?>) paramObj;
            if (CollUtil.isNotEmpty(list) && list.get(0) != null) {
                parseParameter(list.get(0), map);
            }
        } else if (isEndWithExample(paramObj.getClass().getName())) {
            parseMapForExample(paramObj, map);
        } else {
            try {
                Map<String, Object> stringObjectMap = JacksonUtil.bean2Map(paramObj);
                if (stringObjectMap != null) {
                    map.putAll(stringObjectMap);
                }
            } catch (Exception e) {
                log.info("parseParameter error", e);
            }
        }
    }

    /** Tells whether a statement/method name carries the {@code _COUNT} suffix. */
    private static boolean isEndWithCount(String methodName) {
        return methodName.endsWith(COUNT_SUFFIX);
    }

    /** Tells whether a class name ends with {@code Example}. */
    private static boolean isEndWithExample(String className) {
        return className.endsWith(EXAMPLE_SUFFIX);
    }

    /** Returns the current thread's state map, creating it on first use. */
    private static Map<String, Object> localState() {
        Map<String, Object> state = mapThreadLocal.get();
        if (state == null) {
            state = new HashMap<>();
            mapThreadLocal.set(state);
        }
        return state;
    }

    /** Stores the company id for the current invocation. */
    private static void setLocalCompanyId(Long sysCompanyId) {
        localState().put(SYS_COMPANY_ID, sysCompanyId);
    }

    /** Stores the mapper method for the current invocation; {@code null} is ignored. */
    private static void setLocalMethod(Method method) {
        if (method != null) {
            localState().put(METHOD, method);
        }
    }

    /**
     * Tells whether values belonging to the company derived from {@code object} must be
     * encrypted, i.e. whether that company id is in the configured list.
     *
     * @param object statement parameter or result used to derive the company id; with
     *               {@code null} only the id already cached for this invocation is considered
     */
    private static boolean needEncrypt(Object object) {
        Long localCompanyId = getSysCompanyIdFromObj(object);
        if (localCompanyId == null) {
            return false;
        }
        List<Long> sysCompanyIdList = encryptPhoneProperties.getSysCompanyIdList();
        return CollUtil.contains(sysCompanyIdList, localCompanyId);
    }

    /** Clears all state of the current invocation from the thread. */
    private static void removeLocalBrandId() {
        // The map only ever holds the company id and the method, so dropping the whole entry is
        // equivalent to removing both keys and leaves nothing behind on pooled threads.
        mapThreadLocal.remove();
    }

    /**
     * @return the company id cached for the current invocation, or {@code null} when there is
     *         none
     */
    public static Long getLocalCompanyId() {
        Map<String, Object> state = mapThreadLocal.get();
        return state == null ? null : (Long) state.get(SYS_COMPANY_ID);
    }

    /**
     * @return the mapper method cached for the current invocation, or {@code null} when there is
     *         none
     */
    public static Method getLocalMethod() {
        Map<String, Object> state = mapThreadLocal.get();
        return state == null ? null : (Method) state.get(METHOD);
    }

    /**
     * @return whether the company cached for the current invocation requires encryption
     */
    public static boolean needEncrypt() {
        return needEncrypt(null);
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
        // No plugin properties are used.
    }

    /**
     * Loads the encryption settings (company list, key and IV) from the application context.
     */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        encryptPhoneProperties = applicationContext.getBean(EncryptPhoneProperties.class);
        KEY = encryptPhoneProperties.getKey();
        IV = encryptPhoneProperties.getIv();
        if (KEY != null && IV != null) {
            log.info("初始化加密私钥成功");
        }
    }

    /**
     * @return {@code true} unless the configured service name is {@code centerstage}; used to
     *         decide whether the brand lookup through {@link BrandServiceRpc} may be made
     */
    private static boolean checkService() {
        return !"centerstage".equals(encryptPhoneProperties.getService());
    }
}
