package com.bizvane.crypto.aspect;

import com.bizvane.crypto.annotation.RequestDecryptField;
import com.bizvane.crypto.utils.SM3Utils;
import com.bizvane.crypto.utils.SM4Utils;
import jakarta.annotation.PostConstruct;
import jakarta.servlet.http.HttpServletRequest;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.Set;

/**
 * Aspect that decrypts SM4-encrypted request data before controller methods run.
 *
 * <p>For every method matched by {@link #controllerMethods()}, the value of the request header
 * named by {@code spring.crypto.header-key} (default {@code deviceId}) is turned into a key via
 * {@link SM3Utils#generateKeyHex(String)}. Every non-empty {@code String} field annotated with
 * {@link RequestDecryptField} on the method arguments (and on objects reachable from them through
 * fields, collections, map values and object arrays) is then replaced in place by its
 * {@link SM4Utils#decryptSM4(String, String)} result. If the header is missing or blank, nothing
 * is touched.
 */
@Aspect
@Component
public class RequestDecryptAspect {

    private static final System.Logger LOG = System.getLogger(RequestDecryptAspect.class.getName());

    /**
     * Name prefixes of types that are never traversed reflectively. These JDK/framework classes
     * cannot carry {@link RequestDecryptField}. On Java 16+, calling setAccessible on JDK
     * internals (e.g. {@code String.value}) throws InaccessibleObjectException. Walking framework
     * objects such as the servlet request would also traverse large container object graphs.
     */
    private static final String[] SKIPPED_TYPE_PREFIXES = {
            "java.", "javax.", "jakarta.", "jdk.", "sun.", "com.sun.",
            "org.springframework.", "org.apache.", "org.aspectj.", "com.fasterxml."
    };

    /** Name of the request header whose value seeds the decryption key. */
    @Value("${spring.crypto.header-key:deviceId}")
    private String cryptoKey;

    @Autowired
    private HttpServletRequest request;

    @PostConstruct
    public void init() {
        LOG.log(System.Logger.Level.INFO, "RequestDecryptAspect inited.");
    }

    /** Matches every method of any class named {@code *Controller} under {@code com.bizvane}. */
    @Pointcut("execution(* com.bizvane..*Controller.*(..))")
    public void controllerMethods() {
    }

    /**
     * Decrypts annotated fields of the intercepted method's arguments in place.
     *
     * @param joinPoint the intercepted controller invocation
     * @throws IllegalAccessException if a field cannot be read or written reflectively
     *                                (e.g. a final field)
     */
    @Before("controllerMethods()")
    public void replaceTextInFields(JoinPoint joinPoint) throws IllegalAccessException {
        // The key material comes from a request header; without it there is nothing to decrypt.
        String deviceId = request.getHeader(cryptoKey);
        if (deviceId == null || deviceId.trim().isEmpty()) {
            return;
        }
        // Derive the SM4 key (hex) from the header value.
        String keyHex = SM3Utils.generateKeyHex(deviceId);
        // Identity-based, shared across all arguments. The same instance is never decrypted twice,
        // even when its hashCode changes as its fields are decrypted.
        Set<Object> visited = Collections.newSetFromMap(new IdentityHashMap<>());
        for (Object arg : joinPoint.getArgs()) {
            processObjectFields(arg, keyHex, visited);
        }
    }

    /**
     * Recursively decrypts {@link RequestDecryptField} string fields reachable from {@code obj}.
     *
     * @param obj     the object (or container) to process; {@code null} is ignored
     * @param keyHex  the SM4 key produced by {@link SM3Utils#generateKeyHex(String)}
     * @param visited identity set of already-processed objects, used to break cycles
     * @throws IllegalAccessException if a field cannot be read or written reflectively
     */
    private void processObjectFields(Object obj, String keyHex, Set<Object> visited) throws IllegalAccessException {
        if (obj == null || !visited.add(obj)) {
            return;
        }
        // Containers: descend into their elements; their own (JDK) internals are never touched.
        if (obj instanceof Collection<?>) {
            for (Object item : (Collection<?>) obj) {
                processObjectFields(item, keyHex, visited);
            }
            return;
        }
        if (obj instanceof Map<?, ?>) {
            for (Object item : ((Map<?, ?>) obj).values()) {
                processObjectFields(item, keyHex, visited);
            }
            return;
        }
        if (obj instanceof Object[]) {
            for (Object item : (Object[]) obj) {
                processObjectFields(item, keyHex, visited);
            }
            return;
        }
        // Walk the class and its superclasses, so annotated fields inherited from a base DTO are
        // handled too. The walk stops at the first JDK/framework class (at the latest, Object).
        for (Class<?> type = obj.getClass(); type != null && isTraversable(type); type = type.getSuperclass()) {
            for (Field field : type.getDeclaredFields()) {
                if (Modifier.isStatic(field.getModifiers()) || field.isSynthetic() || field.getType().isPrimitive()) {
                    continue;
                }
                field.setAccessible(true);
                Object value = field.get(obj);
                if (value == null) {
                    continue;
                }
                if (value instanceof String) {
                    String text = (String) value;
                    if (field.isAnnotationPresent(RequestDecryptField.class) && StringUtils.hasLength(text)) {
                        field.set(obj, SM4Utils.decryptSM4(text, keyHex));
                    }
                } else {
                    processObjectFields(value, keyHex, visited);
                }
            }
        }
    }

    /**
     * Returns whether instances of {@code type} may have their declared fields inspected.
     * Excludes primitives, arrays, enums, and JDK/framework types listed in
     * {@link #SKIPPED_TYPE_PREFIXES}.
     */
    private static boolean isTraversable(Class<?> type) {
        if (type.isPrimitive() || type.isArray() || Enum.class.isAssignableFrom(type)) {
            return false;
        }
        String name = type.getName();
        for (String prefix : SKIPPED_TYPE_PREFIXES) {
            if (name.startsWith(prefix)) {
                return false;
            }
        }
        return true;
    }
}
