package com.bizvane.crypto.advice;

import com.bizvane.crypto.annotation.ResponseEncryptField;
import com.bizvane.crypto.utils.SM3Utils;
import com.bizvane.crypto.utils.SM4Utils;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import jakarta.annotation.PostConstruct;
import jakarta.servlet.http.HttpServletRequest;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.MethodParameter;
import org.springframework.http.MediaType;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.ControllerAdvice;
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyAdvice;

import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.TimeUnit;

/**
 * Encrypts {@link ResponseEncryptField}-annotated fields of controller responses before they are written.
 *
 * <p>The SM4 key is derived with {@code SM3Utils.generateKeyHex} from the value of the request header
 * named by {@code spring.crypto.header-key} (default {@code deviceId}). If that header is absent or
 * blank, the response body is written unchanged.
 *
 * <p>The body is modified in place. NOTE(review): if a controller returns a shared or cached instance,
 * its annotated fields will be encrypted again on every request — confirm controllers always return
 * fresh objects.
 */
@ControllerAdvice
public class ResponseEncryptAdvice implements ResponseBodyAdvice<Object> {

    private static final System.Logger LOG = System.getLogger(ResponseEncryptAdvice.class.getName());

    /** Maximum nesting depth below the response body that is inspected; deeper values are left untouched. */
    private static final int MAX_RECURSION_DEPTH = 10;

    /**
     * Class-name prefixes of JDK types. Their internals are never traversed: they cannot carry
     * {@link ResponseEncryptField}, and the module system forbids reflective access to most of their
     * private fields.
     */
    private static final List<String> JDK_PACKAGE_PREFIXES = List.of("java.", "javax.", "jdk.", "sun.", "com.sun.");

    /**
     * Per-class cache of the instance fields to inspect (declared and inherited, already made
     * accessible). Weak keys keep cached classes eligible for garbage collection.
     */
    private static final Cache<Class<?>, Field[]> FIELD_CACHE =
            Caffeine.newBuilder()
                    .maximumSize(1000)                       // at most 1000 classes
                    .expireAfterAccess(30, TimeUnit.MINUTES) // evict 30 minutes after last access
                    .weakKeys()
                    .build();

    /** Name of the request header whose value seeds the SM4 key. */
    @Value("${spring.crypto.header-key:deviceId}")
    private String cryptoKey;

    @Autowired
    private HttpServletRequest request;

    @PostConstruct
    public void init() {
        LOG.log(System.Logger.Level.INFO, "ResponseEncryptAdvice inited.");
    }

    /** Applies to every response; bodies without annotated fields pass through unchanged. */
    @Override
    public boolean supports(MethodParameter returnType, Class<? extends HttpMessageConverter<?>> converterType) {
        return true;
    }

    /**
     * Encrypts the annotated fields of {@code body} in place when the key header is present.
     *
     * @return the same {@code body} instance (possibly modified), or {@code null} for a null body
     * @throws IllegalStateException if a field cannot be read, or an annotated field cannot be
     *                               encrypted or written back (e.g. a record component)
     */
    @Override
    public Object beforeBodyWrite(Object body, MethodParameter returnType, MediaType selectedContentType,
                                  Class<? extends HttpMessageConverter<?>> selectedConverterType, ServerHttpRequest serverHttpRequest,
                                  ServerHttpResponse response) {
        if (body == null) {
            return null;
        }
        // The key is derived from a client-supplied header; without it the response is sent as-is.
        String deviceId = request.getHeader(cryptoKey);
        if (deviceId == null || deviceId.trim().isEmpty()) {
            return body;
        }
        String keyHex = SM3Utils.generateKeyHex(deviceId);
        processFields(body, keyHex);
        return body;
    }

    /** Starts a traversal of {@code object} with a fresh, identity-based visited set. */
    private void processFields(Object object, String keyHex) {
        // Identity semantics: distinct objects that happen to be equal() must each be encrypted, and
        // user equals()/hashCode() (which may recurse endlessly on cyclic graphs) must not be invoked.
        Set<Object> visited = Collections.newSetFromMap(new IdentityHashMap<>());
        processFieldsInternal(object, keyHex, visited, 0);
    }

    /**
     * Recursively encrypts annotated fields reachable from {@code object}.
     *
     * <p>Collections, object arrays, map values and {@link Optional} contents are walked element by
     * element; other non-JDK objects have their instance fields inspected. Each object is visited at
     * most once per traversal, so an object first reached near the depth limit is not revisited
     * through a shorter path.
     *
     * @param processedObjects identity set of objects already visited in this traversal
     * @param recursionDepth   nesting depth of {@code object} below the response body
     */
    private void processFieldsInternal(Object object, String keyHex, Set<Object> processedObjects, int recursionDepth) {
        if (object == null || recursionDepth > MAX_RECURSION_DEPTH) {
            return;
        }
        Class<?> clazz = object.getClass();
        // Leaf values hold nothing to encrypt; add() returning false guards against cycles and shared references.
        if (isSimpleValueType(clazz) || !processedObjects.add(object)) {
            return;
        }
        int childDepth = recursionDepth + 1;

        if (object instanceof Collection<?> collection) {
            for (Object item : collection) {
                processFieldsInternal(item, keyHex, processedObjects, childDepth);
            }
            return;
        }
        if (clazz.isArray()) {
            // Primitive arrays (e.g. byte[] download bodies) cannot contain objects; skip them instead
            // of boxing every element.
            if (!clazz.getComponentType().isPrimitive()) {
                for (Object item : (Object[]) object) {
                    processFieldsInternal(item, keyHex, processedObjects, childDepth);
                }
            }
            return;
        }
        if (object instanceof Map<?, ?> map) {
            for (Object entryValue : map.values()) {
                processFieldsInternal(entryValue, keyHex, processedObjects, childDepth);
            }
            return;
        }
        if (object instanceof Optional<?> optional) {
            optional.ifPresent(content -> processFieldsInternal(content, keyHex, processedObjects, childDepth));
            return;
        }
        if (isJdkType(clazz)) {
            return;
        }

        for (Field field : getDeclaredFields(clazz)) {
            Object value;
            try {
                value = field.get(object);
                if (value instanceof String text && !text.isEmpty()
                        && field.isAnnotationPresent(ResponseEncryptField.class)) {
                    field.set(object, SM4Utils.encryptSM4(text, keyHex));
                    continue; // a String has nothing further to traverse
                }
            } catch (Exception e) {
                throw new IllegalStateException("Failed to process the field " + field, e);
            }
            // Recursion stays outside the try so nested failures are not re-wrapped at every level.
            processFieldsInternal(value, keyHex, processedObjects, childDepth);
        }
    }

    /**
     * Returns whether values of {@code clazz} are leaves that hold nothing to encrypt. Enums count as
     * leaves: their constants are JVM-wide singletons, so encrypting one of their fields would alter
     * it for every later request.
     */
    private static boolean isSimpleValueType(Class<?> clazz) {
        return clazz.isPrimitive()
                || Enum.class.isAssignableFrom(clazz)
                || clazz.equals(String.class)
                || Date.class.isAssignableFrom(clazz)
                || LocalDateTime.class.isAssignableFrom(clazz)
                || Number.class.isAssignableFrom(clazz)
                || Boolean.class.isAssignableFrom(clazz)
                || Character.class.isAssignableFrom(clazz);
    }

    /** Returns whether {@code clazz} belongs to the JDK, judged by its fully qualified name. */
    private static boolean isJdkType(Class<?> clazz) {
        String name = clazz.getName();
        return JDK_PACKAGE_PREFIXES.stream().anyMatch(name::startsWith);
    }

    /** Returns the cached, accessible instance fields of {@code clazz} and its non-JDK superclasses. */
    private Field[] getDeclaredFields(Class<?> clazz) {
        return FIELD_CACHE.get(clazz, ResponseEncryptAdvice::collectInstanceFields);
    }

    /**
     * Collects the non-static, non-synthetic fields of {@code type} and of each superclass up to the
     * first JDK class, making them accessible.
     *
     * <p>Fields annotated with {@link ResponseEncryptField} are made accessible with
     * {@code setAccessible(true)}, so any failure propagates. Other fields are included only if
     * {@code trySetAccessible()} succeeds.
     */
    private static Field[] collectInstanceFields(Class<?> type) {
        List<Field> fields = new ArrayList<>();
        for (Class<?> current = type; current != null && !isJdkType(current); current = current.getSuperclass()) {
            for (Field field : current.getDeclaredFields()) {
                // Static state is shared across responses; synthetic fields are compiler-generated
                // (e.g. an inner class's reference to its enclosing instance).
                if (Modifier.isStatic(field.getModifiers()) || field.isSynthetic()) {
                    continue;
                }
                if (field.isAnnotationPresent(ResponseEncryptField.class)) {
                    // Fail loudly rather than silently sending a field marked for encryption in plaintext.
                    field.setAccessible(true);
                    fields.add(field);
                } else if (field.trySetAccessible()) {
                    fields.add(field);
                }
            }
        }
        return fields.toArray(new Field[0]);
    }
}
