Java项目实战:Spring 的 DI 与 IoC 加载原理

本节介绍 Spring 通过 XML 和注解两种方式加载 Bean 的原理,并手写了一个简化版实现,仅供参考。

/**
 * Minimal Spring-like application context.
 *
 * <p>Beans are registered from XML files (through {@link XmlBeanFactory}) and/or by scanning a
 * package for annotated classes (through {@link AnnotationBeanFactory}); both store them in the
 * shared {@link BaseFactory#BEAN_MAP}. The loaders run once, on the first {@link #getBean(String)}
 * call, so repeated lookups return the same instances instead of re-creating every bean.
 *
 * @author: ZhuCJ
 * @date: 2020-08-27 12:32
 */
public class SpringContext implements BaseFactory {

    /** Package to scan for annotated beans; null/empty disables annotation scanning. */
    private String packerName;

    /** Classpath locations of XML bean definition files; null/empty disables XML loading. */
    private String[] xmlPath;

    /** Whether the configured loaders have already run. Guarded by {@code this}. */
    private boolean loaded;

    public SpringContext(String packerName) {
        this.packerName = packerName;
    }

    public SpringContext(String[] xmlPath) {
        this.xmlPath = xmlPath;
    }

    public SpringContext(String packerName, String[] xmlPath) {
        this.packerName = packerName;
        this.xmlPath = xmlPath;
    }

    /**
     * Returns the bean registered under {@code beanName}, loading the configured sources first if
     * that has not happened yet.
     *
     * @param beanName bean id
     * @return the bean, or {@code null} if no bean has that id
     */
    @Override
    public Object getBean(String beanName) {
        ensureLoaded();
        return BEAN_MAP.get(beanName);
    }

    /**
     * Runs the XML and annotation loaders exactly once for this context. If a loader throws, the
     * context stays unloaded and the next lookup retries.
     */
    private synchronized void ensureLoaded() {
        if (loaded) {
            return;
        }
        if (Objects.nonNull(this.xmlPath) && xmlPath.length > 0) {
            // XML beans go in first so the annotation factory's field injection
            // (which looks beans up in BEAN_MAP) can see them.
            loadXml();
        }
        if (!StringUtils.isEmpty(this.packerName)) {
            loadAnnotation();
        }
        loaded = true;
    }

    /**
     * Registers annotated beans from {@code packerName}; the factory is created only for the side
     * effect of filling {@link BaseFactory#BEAN_MAP}.
     */
    public void loadAnnotation() {
        new AnnotationBeanFactory(this.packerName);
    }

    /**
     * Registers the beans declared in every configured XML file; each factory is created only for
     * the side effect of filling {@link BaseFactory#BEAN_MAP}.
     */
    public void loadXml() {
        for (String xmlPath : this.xmlPath) {
            new XmlBeanFactory(xmlPath);
        }
    }

}

/**
* @description: * @author: ZhuCJ * @date: 2020-08-27 10:16 */ public class AnnotationBeanFactory implements BaseFactory { private static final Logger logger = LoggerFactory.getLogger(AnnotationBeanFactory.class); private String packerName ; private static final String EXT = "class"; @Override public Object getBean(String beanName) { return BEAN_MAP.get(beanName); } public AnnotationBeanFactory(String packerName){ this.packerName = packerName; //加载注解Bean loadBean(); //加载注入的属性Bean loadInjectBean(); } /** * 加载bean到容器中 */ public void loadBean(){ //读取包名的路径 String packerPath = null; try { packerPath = getPkgPath(packerName); } catch (UnsupportedEncodingException e) { logger.info("文件路径编码异常:{}",e.getMessage()); throw new RuntimeException("packerName path error"); } logger.info("扫描文件目录的路径:{}", packerPath); // 查找包含Component注解的类 Map<Class<? extends Annotation>, Set<Class<?>>> classesMap = scanClassesByAnnotations(packerName, packerPath, true, Arrays.asList(ServiceTest.class)); if (classesMap.size() == 0){ logger.error("目录:{}下,未获取到需要加载的类", packerPath); return; } //标记的反射对象 Set<Class<?>> classSet = new HashSet<>(); classesMap.forEach((k, v) -> { classSet.addAll(v); }); //默认设置类名为类名小写 for (Class<?> classObj:classSet){ Object object = null; try { object = classObj.newInstance(); } catch (InstantiationException e) { throw new RuntimeException(classObj.getSimpleName()+ " create error"); } catch (IllegalAccessException e) { throw new RuntimeException(classObj.getSimpleName()+ " create error"); } BEAN_MAP.put(StringUtils.uncapitalize(classObj.getSimpleName()),object); } } /** * 加载注入的Bean属性 */ public void loadInjectBean(){ BEAN_MAP.forEach((k,v)->{ setAttributeValue(v); }); } /** * 根据包名获取包的URL * @param pkgName com.demo.controller * @return */ public static String getPkgPath(String pkgName) throws UnsupportedEncodingException { String pkgDirName = pkgName.replace('.', File.separatorChar); URL url = Thread.currentThread().getContextClassLoader().getResource(pkgDirName); return url == null 
? null : URLDecoder.decode(url.getFile(), "UTF-8"); } /** * 获取指定包下包含指定注解的所有类对象的集合 * @param pkgName 包名(com.demo.controller) * @param pkgPath 包路径(/Users/xxx/workspace/java/project/out/production/classes/com/demo/controller) * @param recursive 是否递归遍历子目录 * @param targetAnnotations 指定注解 * @return 以注解和对应类集合构成的键值对 */ public static Map<Class<? extends Annotation>, Set<Class<?>>> scanClassesByAnnotations( String pkgName, String pkgPath, final boolean recursive, List<Class<? extends Annotation>> targetAnnotations){ Map<Class<? extends Annotation>, Set<Class<?>>> resultMap = new HashMap<>(16); Collection<File> allClassFile = getAllClassFile(pkgPath, recursive); for (File curFile : allClassFile){ try { Class<?> curClass = getClassObj(curFile, pkgPath, pkgName); for (Class<? extends Annotation> annotation : targetAnnotations){ if (curClass.isAnnotationPresent(annotation)){ if (!resultMap.containsKey(annotation)){ resultMap.put(annotation, new HashSet<Class<?>>()); } resultMap.get(annotation).add(curClass); } } } catch (ClassNotFoundException e) { logger.error("load class fail", e); } } return resultMap; } /** * 遍历指定目录下所有扩展名为class的文件 * @param pkgPath 包目录 * @param recursive 是否递归遍历子目录 * @return */ private static Collection<File> getAllClassFile(String pkgPath, boolean recursive){ File fPkgDir = new File(pkgPath); if (!(fPkgDir.exists() && fPkgDir.isDirectory())){ logger.error("the directory to package is empty: {}", pkgPath); return null; } return FileUtils.listFiles(fPkgDir, new String[]{EXT}, recursive); } /** * 加载类 * @param file * @param pkgPath * @param pkgName * @return * @throws ClassNotFoundException */ private static Class<?> getClassObj(File file, String pkgPath, String pkgName) throws ClassNotFoundException{ // 考虑class文件在子目录中的情况 String absPath = file.getAbsolutePath().substring(0, file.getAbsolutePath().length() - EXT.length() - 1); String className = absPath.substring(pkgPath.length()).replace(File.separatorChar, '.'); className = className.startsWith(".") ? 
pkgName + className : pkgName + "." + className; return Thread.currentThread().getContextClassLoader().loadClass(className); } /** * 属性赋值 * @param object */ private static void setAttributeValue(Object object){ Class<?> aClass = object.getClass(); Field[] declaredFields = aClass.getDeclaredFields(); for (Field field:declaredFields){ if (field.isAnnotationPresent(AutowiredTest.class)){ //默认取属性值类型小写为 BeanId String simpleName = field.getType().getSimpleName(); Object obj = BEAN_MAP.get(StringUtils.uncapitalize(simpleName)); //允许私有属性赋值 field.setAccessible(true); try { field.set(object,obj); } catch (IllegalAccessException e) { throw new RuntimeException(field.getName() +" attribute set value exception"); } } } } public static void main(String[] args) { BaseFactory baseFactory = new AnnotationBeanFactory("com.spring"); BaseFactory baseFactory1 = new XmlBeanFactory("/spring/test.xml"); Body body = (Body) baseFactory.getBean("body"); System.out.println("加载类:"+BaseFactory.BEAN_MAP.size()+"个"); Object apple = baseFactory.getBean("apple"); System.out.println(body); System.out.println(apple); Order order = new Order(); for (Field field:order.getClass().getDeclaredFields()){ System.out.println(field.getName()); System.out.println(field.getType().getSimpleName()); } } }
/**
 * Bean factory that registers the beans declared in an XML file (parsed with dom4j).
 *
 * <ol>
 *   <li>Locate the XML file on the classpath.</li>
 *   <li>Read every {@code <bean id=".." class="..">} element.</li>
 *   <li>Instantiate the class reflectively and apply each {@code <property name=".." value=".."/>}
 *       child through the JavaBean setter whose property name matches case-insensitively.</li>
 * </ol>
 *
 * <p>Property values are passed to the setter as {@link String}s without conversion.
 *
 * @author: ZhuCJ
 * @date: 2020-08-26 12:43
 */
public class XmlBeanFactory implements BaseFactory {

    /** Classpath location of the XML file, e.g. {@code /spring/test.xml}. */
    private String filePath;


    public XmlBeanFactory(String xmlPath) {
        this.filePath = xmlPath;
        loadBean();
    }

    @Override
    public Object getBean(String beanName) {
        return BEAN_MAP.get(beanName);
    }

    /**
     * Parses the XML file and stores every declared bean in {@link BaseFactory#BEAN_MAP}.
     *
     * @throws RuntimeException if the file is missing/unparseable, a class cannot be loaded or
     *     instantiated, or a property cannot be set
     */
    public void loadBean() {
        Element rootElement = readDocument().getRootElement();
        for (Iterator<?> i = rootElement.elementIterator("bean"); i.hasNext(); ) {
            Element beanElement = (Element) i.next();
            Attribute id = beanElement.attribute("id");
            Attribute aClass = beanElement.attribute("class");
            Class<?> beanClass;
            try {
                beanClass = Class.forName(aClass.getText());
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(id.getText() + "class not Found", e);
            }
            Object object = instantiate(beanClass, id.getText());
            applyProperties(beanElement, beanClass, object);
            BEAN_MAP.put(id.getText(), object);
        }
    }

    /** Reads {@link #filePath} from the classpath into a dom4j document. */
    private Document readDocument() {
        // ClassLoader resource names must not start with '/'.
        String resource = filePath.startsWith("/") ? filePath.substring(1) : filePath;
        java.io.InputStream in = ClassUtils.getDefaultClassLoader().getResourceAsStream(resource);
        if (in == null) {
            throw new RuntimeException("file No Find");
        }
        try (java.io.InputStream stream = in) {
            return new SAXReader().read(stream);
        } catch (Exception e) {
            throw new RuntimeException("file No Find", e);
        }
    }

    /** Creates a bean through its no-argument constructor. */
    private static Object instantiate(Class<?> beanClass, String beanId) {
        try {
            return beanClass.getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException(beanId + " create error", e);
        }
    }

    /** Applies every {@code <property>} child of {@code beanElement} through the matching setter. */
    private static void applyProperties(Element beanElement, Class<?> beanClass, Object object) {
        PropertyDescriptor[] propertyDescriptors;
        try {
            propertyDescriptors = Introspector.getBeanInfo(beanClass).getPropertyDescriptors();
        } catch (IntrospectionException e) {
            throw new RuntimeException(beanClass.getName() + " introspection error", e);
        }
        for (Iterator<?> k = beanElement.elementIterator("property"); k.hasNext(); ) {
            Element propertyElem = (Element) k.next();
            Attribute name = propertyElem.attribute("name");
            Attribute value = propertyElem.attribute("value");
            for (PropertyDescriptor desc : propertyDescriptors) {
                if (!desc.getName().equalsIgnoreCase(name.getText())) {
                    continue;
                }
                Method writeMethod = desc.getWriteMethod();
                if (writeMethod == null) {
                    throw new RuntimeException(desc.getName() + " has no setter on " + beanClass.getName());
                }
                try {
                    writeMethod.invoke(object, value.getValue());
                } catch (IllegalAccessException | InvocationTargetException e) {
                    throw new RuntimeException(desc.getName() + " attribute set value exception", e);
                }
            }
        }
    }

    public static void main(String[] args) {
        BaseFactory baseFactory = new XmlBeanFactory("/spring/test.xml");
        System.out.println(BaseFactory.BEAN_MAP.size());
    }
}
/**
 * Base bean-factory contract shared by the XML and annotation factories.
 *
 * @author: ZhuCJ
 * @date: 2020-08-26 12:42
 */
public interface BaseFactory {


    /**
     * Process-wide bean container shared by every implementation, keyed by bean id. As an
     * interface field it is implicitly {@code public static final}, but the map itself is mutable.
     * {@link ConcurrentHashMap} rejects {@code null} keys and values.
     */
    Map<String, Object> BEAN_MAP = new ConcurrentHashMap<>();


    /**
     * Looks up a bean.
     *
     * @param beanName bean id
     * @return the bean registered under {@code beanName}
     */
    Object getBean(String beanName);
}
/**
 * Contract for looking up orders.
 *
 * @author: ZhuCJ
 * @date: 2020-08-27 13:18
 */
public interface OrderService {

    /**
     * Finds an order.
     *
     * @param id order id
     * @return the order found for {@code id}
     */
    Order selectById(String id);

}
/**
 * Demo {@link OrderService}, marked with {@link ServiceTest} so package scanning registers it.
 *
 * <p>{@link #selectById(String)} ignores its argument and returns the {@link AutowiredTest}-marked
 * {@code order} field as-is.
 *
 * @author: ZhuCJ
 * @date: 2020-08-27 13:18
 */
@ServiceTest
public class OrderServiceImpl implements OrderService {

    /** Expected to be injected by annotation-based field injection; may be {@code null} otherwise. */
    @AutowiredTest
    private Order order;

    @Override
    public Order selectById(String id) {
        // the id is deliberately unused in this demo
        final Order injected = this.order;
        return injected;
    }
}
/**
 * Marks a field for dependency injection, mimicking Spring's {@code @Autowired}.
 *
 * <p>Note: {@code @Inherited} only affects annotations on classes, so it has no effect on this
 * field-level annotation.
 *
 * @author: ZhuCJ
 * @date: 2020-08-27 11:56
 */
@Documented
@Inherited
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface AutowiredTest {
}
/**
 * Marks a class as a bean to be registered by package scanning, mimicking Spring's
 * {@code @Service}. Because of {@code @Inherited}, subclasses of an annotated class also report
 * this annotation via {@code isAnnotationPresent}.
 *
 * @author: ZhuCJ
 * @date: 2020-08-27 10:17
 */
@Documented
@Inherited
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface ServiceTest {
}
/**
 * Order data object. Lombok generates the accessors, {@code equals}/{@code hashCode},
 * {@code toString} and both constructors; the all-args constructor and {@code toString} follow
 * the field declaration order below.
 *
 * @author: ZhuCJ  80004071
 * @date: 2020-08-26 19:09
 */
@Data
@ToString
@NoArgsConstructor
@AllArgsConstructor
public class Order {

    /** Order id. */
    private String id;

    /** Creation time. */
    private Date createTime;

    /** Order number. */
    private String orderNo;
}
/**
 * Demo entry point: builds a context from the {@code com.spring} package plus
 * {@code /spring/test.xml}, fetches {@code orderServiceImpl} and prints the order it returns.
 *
 * @author: ZhuCJ  80004071
 * @date: 2020-08-11 12:36
 */
public class Main {

    public static void main(String[] args) {
        String[] xmlLocations = {"/spring/test.xml"};
        SpringContext context = new SpringContext("com.spring", xmlLocations);
        Object bean = context.getBean("orderServiceImpl");
        OrderService service = (OrderService) bean;
        System.out.println(service.selectById("1"));
    }
}

测试结果(运行 Main 的输出;其中"我已经被创建"和"加载到容器中Bean数量"两行由文中未列出的代码打印):

我已经被创建
我已经被创建
Order(id=12121212, createTime=null, orderNo=T21000)
加载到容器中Bean数量:10

--------------------------------------------------------------------

原文地址:https://www.cnblogs.com/zhucj-java/p/13598673.html