Hand-Writing a Simple SpringMVC

I. Preface

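This is a small demo in which I implement simple SpringMVC-like functionality myself.

The goal is to pass a name parameter through the browser's address bar and print "my name is " + name. Instead of using SpringMVC, I define a few annotations of my own and implement the core functionality of DispatcherServlet. Working through this demo deepens one's understanding of the SpringMVC source code.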

First, here is the result:

(Screenshot: request with the name parameter)

(Screenshot: request without the name parameter)

II. DispatcherServlet Workflow

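  1. Load the configuration file
  2. Scan all relevant classes
  3. Instantiate all relevant classes
  4. Inject dependencies (autowiring)
  5. Initialize the HandlerMapping
  6. Wait for requests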
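Stages 3 and 4 are the IoC half of this workflow. The sketch below is a hedged, servlet-free model of them in plain Java, not the article's code: the names MiniContainer, Greeter, GreeterImpl and Client are hypothetical, and the nested @Service/@Autowired annotations stand in for the custom ones defined later. It instantiates annotated classes, registers each instance under its lower-cased class name and under every interface it implements, then fills @Autowired fields by looking up the field type's name, which is the same lookup scheme the article's doInstance and doAutowired use.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.util.HashMap;
import java.util.Map;

// Hypothetical, servlet-free model of stage 3 (instantiate) and stage 4 (autowire).
final class MiniContainer {
    @Target(ElementType.TYPE)
    @Retention(RetentionPolicy.RUNTIME)
    @interface Service {}

    @Target(ElementType.FIELD)
    @Retention(RetentionPolicy.RUNTIME)
    @interface Autowired {}

    interface Greeter {
        String greet(String name);
    }

    @Service
    static class GreeterImpl implements Greeter {
        @Override
        public String greet(String name) {
            return "my name is " + name;
        }
    }

    @Service
    static class Client {
        @Autowired
        Greeter greeter;
    }

    final Map<String, Object> iocMap = new HashMap<>();

    // Stage 3: instantiate each @Service class and register it under its lower-cased
    // class name and under the name of every interface it implements.
    void doInstance(Class<?>... classes) {
        for (Class<?> clazz : classes) {
            if (!clazz.isAnnotationPresent(Service.class)) {
                continue;
            }
            try {
                Object instance = clazz.getDeclaredConstructor().newInstance();
                iocMap.put(clazz.getName().toLowerCase(), instance);
                for (Class<?> itf : clazz.getInterfaces()) {
                    iocMap.put(itf.getName(), instance);
                }
            } catch (ReflectiveOperationException e) {
                throw new IllegalStateException("cannot instantiate " + clazz.getName(), e);
            }
        }
    }

    // Stage 4: fill every @Autowired field with the bean registered under the field's type name.
    void doAutowired() {
        for (Object bean : iocMap.values()) {
            for (Field field : bean.getClass().getDeclaredFields()) {
                if (!field.isAnnotationPresent(Autowired.class)) {
                    continue;
                }
                field.setAccessible(true);
                try {
                    field.set(bean, iocMap.get(field.getType().getName()));
                } catch (IllegalAccessException e) {
                    throw new IllegalStateException(e);
                }
            }
        }
    }
}
```

Because beans are keyed by name in a plain map, a second implementation of the same interface would silently overwrite the first. The article's doInstance behaves the same way, which is acceptable for a demo with one implementation per interface.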

III. Code Walkthrough

1. First, the dependencies in the POM file:

<dependencies>
  <dependency>
    <groupId>javax.servlet</groupId>
    <artifactId>servlet-api</artifactId>
    <version>2.5</version>
    <!-- supplied by the servlet container at runtime -->
    <scope>provided</scope>
  </dependency>
  <dependency>
    <groupId>org.apache.commons</groupId>
    <artifactId>commons-lang3</artifactId>
    <version>3.10</version>
  </dependency>
  <dependency>
    <groupId>org.projectlombok</groupId>
    <artifactId>lombok</artifactId>
    <version>1.18.12</version>
    <!-- only needed at compile time -->
    <scope>provided</scope>
  </dependency>
  <dependency>
    <groupId>ch.qos.logback</groupId>
    <artifactId>logback-core</artifactId>
    <version>1.2.3</version>
  </dependency>
  <dependency>
    <groupId>ch.qos.logback</groupId>
    <artifactId>logback-classic</artifactId>
    <version>1.2.3</version>
  </dependency>
</dependencies>

There are only a few dependencies and no Spring dependency at all; the main one is the Servlet API.


2. Configuration files:

2.1. The application.properties file:

scanPackage=com.qunar.framework.demo

This specifies the package to scan.
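A minimal sketch, assuming the properties text is already in memory (the real servlet loads it with getResourceAsStream), of how the scanPackage value is read and turned into a classpath directory path. ScanConfig and its methods are hypothetical helper names.

```java
import java.io.IOException;
import java.io.StringReader;
import java.util.Properties;

// Hypothetical helper: reads scanPackage and converts it to a classpath path.
final class ScanConfig {
    // Parses properties text; in the servlet this text comes from application.properties
    static String scanPackage(String propertiesText) {
        Properties props = new Properties();
        try {
            props.load(new StringReader(propertiesText));
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
        return props.getProperty("scanPackage");
    }

    // com.qunar.framework.demo -> com/qunar/framework/demo
    static String toResourcePath(String packageName) {
        // "\\." escapes the dot, which is otherwise a regex wildcard
        return packageName.replaceAll("\\.", "/");
    }
}
```

The dot must be escaped because replaceAll takes a regular expression: an unescaped "." matches every character and would turn the whole name into slashes. packageName.replace('.', '/') does the same job without regular expressions.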

2.2. The web.xml file:

<!DOCTYPE web-app PUBLIC
 "-//Sun Microsystems, Inc.//DTD Web Application 2.3//EN"
 "http://java.sun.com/dtd/web-app_2_3.dtd" >
 
<web-app>
  <display-name>MySpringMVC</display-name>
  <servlet>
    <servlet-name>mvc</servlet-name>
    <servlet-class>com.qunar.framework.webmvc.DispatcherServlet</servlet-class>
    <init-param>
      <param-name>contextConfigLocation</param-name>
      <param-value>/application.properties</param-value>
    </init-param>
    <load-on-startup>1</load-on-startup>
  </servlet>
  <servlet-mapping>
    <servlet-name>mvc</servlet-name>
    <url-pattern>/*</url-pattern>
  </servlet-mapping>
</web-app>

3. Project directory structure: (screenshot)

4. Custom annotations:

@Controller:

import java.lang.annotation.*;

@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Controller {
    String value() default "";
}

@Service:

import java.lang.annotation.*;

@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Service {
    String value() default "";
}

@Autowired:

import java.lang.annotation.*;

@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Autowired {
    String value() default "";
}

@RequestMapping:

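import java.lang.annotation.*;

// Read from controller classes (base URL) and from handler methods, so both targets are needed
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RequestMapping {
    String value() default "";
}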

@RequestParam:

import java.lang.annotation.*;

@Target(ElementType.PARAMETER)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RequestParam {
    String value() default "";
}
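All five annotations use RetentionPolicy.RUNTIME because the DispatcherServlet only discovers them through reflection while the application is running. The hedged sketch below (RetentionDemo, CompileTimeOnly and the nested DemoController are hypothetical) shows that difference, and also why @RequestMapping needs both ElementType.TYPE and ElementType.METHOD: it is read from the controller class and from each handler method.

```java
import java.lang.annotation.Annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

final class RetentionDemo {
    // Same shape as @RequestMapping: allowed on classes and on methods
    @Target({ElementType.TYPE, ElementType.METHOD})
    @Retention(RetentionPolicy.RUNTIME)
    @interface RequestMapping {
        String value() default "";
    }

    // Hypothetical counter-example: CLASS retention is not visible to runtime reflection
    @Target(ElementType.TYPE)
    @Retention(RetentionPolicy.CLASS)
    @interface CompileTimeOnly {}

    @RequestMapping("/demo")
    @CompileTimeOnly
    static class DemoController {
        @RequestMapping("/get")
        public void get() {}
    }

    static boolean visibleAtRuntime(Class<?> clazz, Class<? extends Annotation> annotation) {
        return clazz.isAnnotationPresent(annotation);
    }

    // Reads the method-level @RequestMapping value of a public no-arg method
    static String methodMapping(Class<?> clazz, String methodName) {
        try {
            Method method = clazz.getMethod(methodName);
            return method.getAnnotation(RequestMapping.class).value();
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException(e);
        }
    }
}
```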

5. The custom Handler class:

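import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Pattern;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.lang3.StringUtils;
// (the import of the custom @RequestParam annotation is omitted; its package is not shown in the article)

public class Handler {
    protected Object controller;                  // controller instance that owns the method
    protected Method method;                      // handler method mapped to the URL
    protected Pattern pattern;                    // URL pattern this handler matches
    protected Map<String, Integer> paramIndexMap; // parameter name (or type name) -> parameter position

    public Handler(Object controller, Method method, Pattern pattern) {
        this.controller = controller;
        this.method = method;
        this.pattern = pattern;
        this.paramIndexMap = new HashMap<>();
        putParamIndexMapping(method);
    }

    private void putParamIndexMapping(Method method) {
        // Record the position of every parameter annotated with @RequestParam
        Annotation[][] annotations = method.getParameterAnnotations();
        for (int i = 0; i < annotations.length; i++) {
            for (Annotation annotation : annotations[i]) {
                if (annotation instanceof RequestParam) {
                    String paramName = ((RequestParam) annotation).value();
                    if (StringUtils.isNotBlank(paramName)) {
                        paramIndexMap.put(paramName, i);
                    }
                }
            }
        }
        // Record the positions of the HttpServletRequest and HttpServletResponse parameters, keyed by type name
        Class<?>[] paramTypes = method.getParameterTypes();
        for (int i = 0; i < paramTypes.length; i++) {
            Class<?> paramType = paramTypes[i];
            if (paramType == HttpServletRequest.class || paramType == HttpServletResponse.class) {
                paramIndexMap.put(paramType.getName(), i);
            }
        }
    }
}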
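The core of Handler is putParamIndexMapping, which records the position of each parameter. Here is a runnable sketch of the same idea, assuming no servlet API on the classpath: a StringBuilder parameter stands in for HttpServletRequest, and ParamIndexDemo, Controller and findMethod are hypothetical names.

```java
import java.lang.annotation.Annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;

final class ParamIndexDemo {
    @Target(ElementType.PARAMETER)
    @Retention(RetentionPolicy.RUNTIME)
    @interface RequestParam {
        String value() default "";
    }

    // Hypothetical handler: StringBuilder plays the role of HttpServletRequest
    static class Controller {
        public void get(StringBuilder out, @RequestParam("name") String name, @RequestParam("age") Integer age) {}
    }

    static Method findMethod(Class<?> clazz, String name) {
        for (Method m : clazz.getMethods()) {
            if (m.getName().equals(name)) {
                return m;
            }
        }
        throw new IllegalArgumentException("no public method " + name);
    }

    // Maps each @RequestParam name to the position of the parameter it annotates
    static Map<String, Integer> paramIndexMap(Method method) {
        Map<String, Integer> map = new HashMap<>();
        Annotation[][] annotations = method.getParameterAnnotations();
        for (int i = 0; i < annotations.length; i++) {
            for (Annotation annotation : annotations[i]) {
                if (annotation instanceof RequestParam) {
                    String name = ((RequestParam) annotation).value().trim();
                    if (!name.isEmpty()) {
                        map.put(name, i);
                    }
                }
            }
        }
        // Framework-supplied parameters are keyed by type name, as Handler does for request/response
        Class<?>[] types = method.getParameterTypes();
        for (int i = 0; i < types.length; i++) {
            if (types[i] == StringBuilder.class) {
                map.put(types[i].getName(), i);
            }
        }
        return map;
    }
}
```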

6. The custom DispatcherServlet:

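import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.net.URL;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import javax.servlet.ServletConfig;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
// (imports of Handler and the custom annotations are omitted; their packages are not shown in the article)

@Slf4j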
public class DispatcherServlet extends HttpServlet {
    private static final long serialVersionUID = 1L;
    private Properties contextConfig = new Properties();
    private List<String> classNames = new ArrayList<>();
    private Map<String, Object> iocMap = new HashMap<>();
    private List<Handler> handlerMapping = new ArrayList<>();
 
    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException {
        this.doPost(req, resp);
    }
 
    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException {
        // Wait for requests: GET is delegated here too, so every request is dispatched from this point
        try {
            doDispatch(req, resp);
        } catch (Exception exception) {
            resp.getWriter().write("500 Exception");
            log.error("500 Exception. Cause: {}", exception.getMessage());
            exception.printStackTrace();
        }
    }
 
    private void doDispatch(HttpServletRequest req, HttpServletResponse resp) throws Exception {
        Handler handler = getHandler(req);
        if (handler == null) {
            // No handler matched: 404
            log.info("404 Not Found");
            resp.getWriter().write("404 Not Found");
            return;
        }
        // Parameter types of the handler method
        Class<?>[] parameterTypes = handler.method.getParameterTypes();
        // Holds the argument values to be passed to the handler method
        Object[] parameterValues = new Object[parameterTypes.length];
 
        Map<String, String[]> parameterMap = req.getParameterMap();
        for (Map.Entry<String, String[]> entry : parameterMap.entrySet()) {
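            // Join the value(s) into one string, e.g. "[tom]" -> "tom"; the brackets must be escaped in the regex
            String value = Arrays.toString(entry.getValue()).replaceAll("\\[|\\]", "").replaceAll("/+", "/");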
            log.info(value);
            // Fill in the value only if the handler declares a @RequestParam with this name
            if (!handler.paramIndexMap.containsKey(entry.getKey())) {
                continue;
            }
            Integer index = handler.paramIndexMap.get(entry.getKey());
            parameterValues[index] = convert(parameterTypes[index], value);
        }
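        // Pass in the request and response objects, but only if the handler method declares them
        // (unboxing a null index would otherwise throw a NullPointerException)
        Integer reqIndex = handler.paramIndexMap.get(HttpServletRequest.class.getName());
        if (reqIndex != null) {
            parameterValues[reqIndex] = req;
        }
        Integer respIndex = handler.paramIndexMap.get(HttpServletResponse.class.getName());
        if (respIndex != null) {
            parameterValues[respIndex] = resp;
        }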
        handler.method.invoke(handler.controller, parameterValues);
    }
 
    private Object convert(Class<?> parameterType, String value) {
        if (parameterType == Integer.class) {
            return Integer.valueOf(value);
        }
        return value;
    }
 
    private Handler getHandler(HttpServletRequest req) {
        if (handlerMapping.isEmpty()) {
            return null;
        }
        String requestURI = req.getRequestURI();
        String contextPath = req.getContextPath();
        requestURI = requestURI.replace(contextPath, "").replaceAll("/+", "/");
        for (Handler handler : handlerMapping) {
            Matcher matcher = handler.pattern.matcher(requestURI);
            if (!matcher.matches()) {
                continue;
            }
            return handler;
        }
        return null;
    }
 
    @Override
    public void init(ServletConfig config) {
        // Startup begins here:
        // 1. Load the configuration file
        loadConfig(config.getInitParameter("contextConfigLocation"));
        // 2. Scan the relevant classes
        doScanner(contextConfig.getProperty("scanPackage"));
        // 3. Instantiate the relevant classes
        try {
            doInstance();
        } catch (Exception exception) {
            log.error("Execute doInstance method fail.");
            exception.printStackTrace();
        }
        // 4. Inject dependencies (autowiring)
        doAutowired();
        // 5. Initialize the HandlerMapping
        initHandlerMapping();
    }
 
    private void initHandlerMapping() {
        if (iocMap.isEmpty()) {
            return;
        }
        for (Map.Entry<String, Object> entry : iocMap.entrySet()) {
            Class<?> clazz = entry.getValue().getClass();
            if (!clazz.isAnnotationPresent(Controller.class)) {
                continue;
            }
            String baseUrl = "";
            if (clazz.isAnnotationPresent(RequestMapping.class)) {
                RequestMapping requestMapping = clazz.getAnnotation(RequestMapping.class);
                baseUrl = requestMapping.value();
            }
            // Scan all public methods
            for (Method method : clazz.getMethods()) {
                if (!method.isAnnotationPresent(RequestMapping.class)) {
                    continue;
                }
                RequestMapping requestMapping = method.getAnnotation(RequestMapping.class);
                String regex = ("/" + baseUrl + requestMapping.value()).replaceAll("/+", "/");
                Pattern pattern = Pattern.compile(regex);
                handlerMapping.add(new Handler(entry.getValue(), method, pattern));
                log.info("Mapping: {}.{}", regex, method);
            }
        }
    }
 
    private void doAutowired() {
        if (iocMap.isEmpty()) {
            return;
        }
        // Iterate over all beans and assign every @Autowired field
        for (Map.Entry<String, Object> entry : iocMap.entrySet()) {
            Field[] fields = entry.getValue().getClass().getDeclaredFields();
            for (Field field : fields) {
                if (!field.isAnnotationPresent(Autowired.class)) {
                    continue;
                }
                Autowired autowired = field.getAnnotation(Autowired.class);
                String beanName = autowired.value();
                if (beanName != null) {
                    beanName = beanName.trim();
                }
                if (StringUtils.isBlank(beanName)) {
                    beanName = field.getType().getName();
                }
                field.setAccessible(true);
                try {
                    field.set(entry.getValue(), iocMap.get(beanName));
                } catch (IllegalAccessException e) {
                    log.error("AutoWired fail,beanName: {}", beanName);
                    e.printStackTrace();
                    continue;
                }
            }
        }
    }
 
    private void doInstance() throws Exception {
        if (classNames.isEmpty()) {
            return;
        }
        for (String className : classNames) {
            Class<?> clazz = Class.forName(className);
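            // Use the custom name if one is given; otherwise default to the lower-cased fully qualified class name
            // (the usual "lower-case only the first letter" convention is not applied here)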
            if (clazz.isAnnotationPresent(Controller.class)) {
                Controller controller = clazz.getAnnotation(Controller.class);
                String beanName = controller.value();
                if (StringUtils.isBlank(beanName)) {
                    beanName = clazz.getName().toLowerCase();
                }
                Object instance = clazz.getDeclaredConstructor().newInstance();
                iocMap.put(beanName, instance);
            } else if (clazz.isAnnotationPresent(Service.class)) {
                Service service = clazz.getAnnotation(Service.class);
                String beanName = service.value();
                if (StringUtils.isBlank(beanName)) {
                    beanName = clazz.getName().toLowerCase();
                }
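                Object instance = clazz.getDeclaredConstructor().newInstance();
                iocMap.put(beanName, instance);
                // Also register the instance under each interface it implements, so fields typed as the interface can be autowired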
                for (Class<?> clazzInterface : clazz.getInterfaces()) {
                    iocMap.put(clazzInterface.getName(), instance);
                }
            } else {
                continue;
            }
        }
    }
 
    private void doScanner(String scanPackage) {
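        // Convert the package name to a classpath path, e.g. com.qunar.framework.demo -> com/qunar/framework/demo
        // (no leading "/": ClassLoader.getResource expects a name relative to the classpath root)
        URL url = this.getClass().getClassLoader().getResource(scanPackage.replaceAll("\\.", "/"));
        File classDir = new File(url.getFile());
        for (File file : classDir.listFiles()) {
            if (file.isDirectory()) {
                doScanner(scanPackage + "." + file.getName());
            } else if (file.getName().endsWith(".class")) {
                String className = scanPackage + "." + file.getName().replace(".class", "");
                classNames.add(className);
            }
        }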
        }
    }
 
    private void loadConfig(String location) {
        InputStream inputStream = this.getClass().getResourceAsStream(location);
        try {
            contextConfig.load(inputStream);
        } catch (IOException e) {
            log.error("Load fail, location: {}", location);
            e.printStackTrace();
        } finally {
            if (inputStream != null) {
                try {
                    inputStream.close();
                } catch (IOException e) {
                    log.error("Close fail, inputStream: {}", inputStream);
                    e.printStackTrace();
                }
            }
        }
    }
}
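URL matching is split between initHandlerMapping (building the pattern) and getHandler (normalizing the request path and matching it). The sketch below isolates just that logic from the servlet; UrlMappingDemo, register and lookup are hypothetical names, and a plain String stands in for the Handler object.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;

final class UrlMappingDemo {
    private final List<Pattern> patterns = new ArrayList<>();
    private final List<String> handlers = new ArrayList<>();

    // As in initHandlerMapping: join class-level and method-level paths, collapse repeated slashes
    void register(String baseUrl, String methodUrl, String handlerName) {
        String regex = ("/" + baseUrl + methodUrl).replaceAll("/+", "/");
        patterns.add(Pattern.compile(regex));
        handlers.add(handlerName);
    }

    // As in getHandler: strip the context path, collapse slashes, return the first full match
    String lookup(String requestUri, String contextPath) {
        String path = requestUri.replace(contextPath, "").replaceAll("/+", "/");
        for (int i = 0; i < patterns.size(); i++) {
            if (patterns.get(i).matcher(path).matches()) {
                return handlers.get(i);
            }
        }
        return null; // the servlet answers "404 Not Found" in this case
    }
}
```

One caveat carried over from getHandler: String.replace removes every occurrence of the context path, not only the leading one, so with context path /demo a request for /demo/demo/get is reduced to /get and no longer matches /demo/get. Checking startsWith(contextPath) and taking substring(contextPath.length()) avoids that.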

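This is the core class: it plays the role of SpringMVC's DispatcherServlet, handling both startup initialization and request dispatching.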
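At request time, doDispatch fills an Object[] by parameter position and calls Method.invoke. The following hedged sketch reproduces that step without a servlet container: DispatchDemo, Greeting and dispatch are hypothetical, and plain maps stand in for handler.paramIndexMap and req.getParameterMap().

```java
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Map;

final class DispatchDemo {
    // Hypothetical handler standing in for DemoController (no servlet types needed)
    public static class Greeting {
        public String get(String name, Integer age) {
            return "my name is " + name + ", age " + age;
        }
    }

    // Same rule as the article's convert(): only Integer is converted, everything else stays a String
    static Object convert(Class<?> type, String value) {
        if (type == Integer.class) {
            return Integer.valueOf(value);
        }
        return value;
    }

    // Fills an argument array by parameter position, then calls the handler via reflection
    static Object dispatch(Object controller, Method method,
                           Map<String, Integer> paramIndexMap, Map<String, String[]> requestParams) {
        Class<?>[] types = method.getParameterTypes();
        Object[] args = new Object[types.length]; // slots with no matching request parameter stay null
        for (Map.Entry<String, String[]> entry : requestParams.entrySet()) {
            Integer index = paramIndexMap.get(entry.getKey());
            if (index == null) {
                continue; // the handler does not declare this request parameter
            }
            // Same joining as doDispatch: {"tom"} -> "[tom]" -> "tom"
            String value = Arrays.toString(entry.getValue()).replaceAll("\\[|\\]", "");
            args[index] = convert(types[index], value);
        }
        try {
            return method.invoke(controller, args);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("handler invocation failed", e);
        }
    }
}
```

A request parameter the handler does not declare is skipped, and a declared one that is missing leaves its slot null; with the article's DemoServiceImpl that would make the page read "my name is null".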

7. Now it is time to verify that this homemade SpringMVC works, using a service and a controller I wrote:

7.1 Service:

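public interface IDemoService {
    String get(String name);
}

// @Service is required: doInstance() only instantiates classes annotated with @Controller or @Service
@Service
public class DemoServiceImpl implements IDemoService {
    @Override
    public String get(String name) {
        return "my name is " + name;
    }
}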

7.2 Controller:

@Controller
@RequestMapping("/demo")
@Slf4j
public class DemoController {
    @Autowired
    IDemoService service;
 
    @RequestMapping("/get")
    public void get(HttpServletRequest req, HttpServletResponse resp, @RequestParam("name") String name) {
        String res = service.get(name);
        try {
            resp.setContentType("text/html;charset=UTF-8");
            resp.getWriter().println(res);
        } catch (IOException e) {
            log.info(e.getMessage());
            e.printStackTrace();
        }
    }
}

Together with the screenshots at the beginning, this confirms that the homemade SpringMVC works.

IV. Closing Remarks

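This implements only the most basic SpringMVC functionality. Its purpose is to deepen my understanding of how SpringMVC maps a request URL to a handler method; the real SpringMVC is, of course, far more complex.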

GitHub repository for the demo: https://github.com/Happy-Ape/Spring

Original article: https://www.cnblogs.com/ericz2j/p/13553719.html