SPI扩展机制在框架中的使用

java也有自己的SPI实现,但是有很多小毛病,比如:会一次性加载所有扩展实现,不能支持一些复杂的元数据表达,据说多了类加载器同时加载会有并发问题(没有考证过)。所以很多框架都提供了SPI机制供使用者自己扩展,例如Dubbo,使用SPI还可以实现按需加载扩展点。之前看过Dubbo的SPI实现,其实它的整个核心功能都是围绕SPI来实现的,所以显得很复杂。接下来看一个轻量级的SPI实现——来源于公司一个生产级的框架。不多说,直接上代码:

两个注解的定义:

@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
public @interface Spi {
    Scope scope() default Scope.PROTOTYPE;
}
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
public @interface SpiMeta {
    String name() default "";
}

核心的扩展点加载器:

public class ExtensionLoader<T> {
    private static final Logger logger = LoggerFactory.getLogger(ExtensionLoader.class);
    private ConcurrentMap<String, T> singletonInstances = null;
    //存放所有扩展点类的集合,key是扩展名,每一个扩展文件中可以定义多个扩展类
    private ConcurrentMap<String, Class<T>> extensionClasses = null;
    //加载该路径下所有扩展类
    private static final String SPI_LOCATION = "META-INF/services/";
    private ClassLoader classLoader;
    private Class<T> type;
    private volatile boolean init = false;
    //存储扩展点对象和其对应的扩展加载器
    private static final Map<Class<?>, ExtensionLoader<?>> extensionLoaders = new ConcurrentHashMap<>();

    //通过构造方法设置类加载器,使用Thread.currentThread().getContextClassLoader(),如果没有setContextClassLoader()则为系统类加载器AppClassLoader
    private ExtensionLoader(Class<T> type) {
        this(type, Thread.currentThread().getContextClassLoader());
    }

    private ExtensionLoader(Class<T> type, ClassLoader classLoader) {
        this.type = type;
        this.classLoader = classLoader;
    }

    @SuppressWarnings("unchecked")
    public static <T> ExtensionLoader<T> getExtensionLoader(Class<T> type) {
        //根据扩展点类型获取对应的扩展加载器
        ExtensionLoader<T> loader = (ExtensionLoader<T>) extensionLoaders.get(type);
        if (loader == null) {
            //如果没有该扩展加载器,则进行初始化
            loader = initExtensionLoader(type);
        }
        return loader;
    }

    //初始化方法是静态的且加了锁,相当于在类对象上加锁,可以保证同时只有一个扩展器的初始化操作
    @SuppressWarnings("unchecked")
    private static synchronized <T> ExtensionLoader<T> initExtensionLoader(Class<T> type) {
        ExtensionLoader<T> loader = (ExtensionLoader<T>) extensionLoaders.get(type);
        if (loader == null) {
            loader = new ExtensionLoader<>(type);
            extensionLoaders.putIfAbsent(type, loader);
            loader = (ExtensionLoader<T>) extensionLoaders.get(type);
        }
        return loader;
    }

    @SuppressWarnings("unchecked")
    public List<T> getExtensions() {
        checkAndInit();
        List<T> extensions = new ArrayList<>(extensionClasses.size());
        for (Map.Entry<String, Class<T>> entry : extensionClasses.entrySet()) {
            extensions.add(getExtension(entry.getKey()));
        }
        extensions.sort(new ExtensionOrderComparator<T>());
        return extensions;
    }

    //根据扩展名获取扩展点对象
    public T getExtension(String name) {
        checkAndInit();
        if (name == null) {
            return null;
        }
        try {
            //注意,@Spi是使用在扩展点接口上的,@SpiMeta是使用在实现类上的
            Spi spi = type.getAnnotation(Spi.class);
            //单例类型
            if (spi.scope() == Scope.SINGLETON) {
                return getSingletonInstance(name);
            } else {
                //原型类型
                Class<T> clz = extensionClasses.get(name);
                if (clz == null) {
                    return null;
                }
                return clz.newInstance();
            }
        } catch (Exception e) {
            new RuntimeException(type.getName() + ":Error when getExtension " + name, e);
        }
        return null;
    }

    @SuppressWarnings("unchecked")
    private T getSingletonInstance(String name) throws InstantiationException, IllegalAccessException {
        T obj = singletonInstances.get(name);
        if (obj != null) {
            return obj;
        }
        Class<T> clz = extensionClasses.get(name);
        if (clz == null) {
            return null;
        }
        //加锁对象为集合对象,确保只有一个线程能创建扩展点对象
        synchronized (singletonInstances) {
            obj = singletonInstances.get(name);
            if (obj != null) {
                return obj;
            }
            obj = clz.newInstance();
            singletonInstances.put(name, obj);
        }
        return obj;
    }

    private void checkAndInit() {
        //init被volatile修饰,确保只有一个线程进行初始化
        if (!init) {
            loadExtensionClasses();
        }
    }

    //这里在方法级别加锁,锁对象是扩展点对应的扩展类加载器对象
    private synchronized void loadExtensionClasses() {
        if (init) {
            return;
        }
        //将META-INF/services/目录下的扩展点加载进集合中
        extensionClasses = loadExtensionClasses(SPI_LOCATION);
        singletonInstances = new ConcurrentHashMap<>();
        init = true;
    }

    private ConcurrentMap<String, Class<T>> loadExtensionClasses(String prefix) {
        //根据前缀和类的全限定名来读取文件,所以这里注意扩展点的文件名称必须是类的全限定名
        String fullName = prefix + type.getName();
        List<String> classNames = new ArrayList<String>();
        try {
            Enumeration<URL> urls;
            if (classLoader == null) {
                urls = ClassLoader.getSystemResources(fullName);
            } else {
                urls = classLoader.getResources(fullName);
            }
            if (urls == null || !urls.hasMoreElements()) {
                return new ConcurrentHashMap<>();
            }
            while (urls.hasMoreElements()) {
                URL url = urls.nextElement();
                //解析类
                parseUrl(type, url, classNames);
            }
        } catch (Exception e) {
            throw new RuntimeException("ExtensionLoader loadExtensionClasses error, prefix: " + prefix + " type: " + type.getClass(), e);
        }
        //将类加载进内存,并放入集合中
        return loadClass(classNames);
    }


    @SuppressWarnings("unchecked")
    private ConcurrentMap<String, Class<T>> loadClass(List<String> classNames) {
        ConcurrentMap<String, Class<T>> map = new ConcurrentHashMap<String, Class<T>>();
        for (String className : classNames) {
            try {
                Class<T> clz;
                if (classLoader == null) {
                    //classLoader为空,使用加载当前类的类加载器进行加载
                    clz = (Class<T>) Class.forName(className);
                } else {
                    clz = (Class<T>) Class.forName(className, true, classLoader);
                }
                checkExtensionType(clz);
                String spiName = getSpiName(clz);
                if (map.containsKey(spiName)) {
                    new RuntimeException(clz.getName() + ":Error spiName already exist " + spiName);
                } else {
                    map.put(spiName, clz);
                }
            } catch (Exception e) {
                logger.error(type.getName() + ":" + "Error load spi class", e);
            }
        }
        return map;

    }

    private void checkExtensionType(Class<T> clz) {
        checkClassPublic(clz);

        checkConstructorPublic(clz);

        checkClassInherit(clz);
    }

    private void checkClassPublic(Class<T> clz) {
        if (!Modifier.isPublic(clz.getModifiers())) {
            new RuntimeException(clz.getName() + ":Error is not a public class");
        }
    }

    private void checkClassInherit(Class<T> clz) {
        if (!type.isAssignableFrom(clz)) {
            new RuntimeException(clz.getName() + ":Error is not instanceof " + type.getName());
        }
    }

    private void checkConstructorPublic(Class<T> clz) {
        Constructor<?>[] constructors = clz.getConstructors();

        if (constructors == null || constructors.length == 0) {
            new RuntimeException(clz.getName() + ":Error has no public no-args constructor");
        }

        for (Constructor<?> constructor : constructors) {
            if (Modifier.isPublic(constructor.getModifiers()) && constructor.getParameterTypes().length == 0) {
                return;
            }
        }

        new RuntimeException(clz.getName() + ":Error has no public no-args constructor");
    }



    public String getSpiName(Class<?> clz) {
        SpiMeta spiMeta = clz.getAnnotation(SpiMeta.class);
        //如果SpiMeta中没有定义name属性,则使用类型,如@SpiMeta(name = "coreSamplePrinter")
        return (spiMeta != null && !"".equals(spiMeta.name())) ? spiMeta.name() : clz.getSimpleName();
    }

    private void parseUrl(Class<T> type, URL url, List<String> classNames) throws ServiceConfigurationError {
        InputStream inputStream = null;
        BufferedReader reader = null;
        try {
            inputStream = url.openStream();
            reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8));
            String line;
            int indexNumber = 0;
            while ((line = reader.readLine()) != null) {
                indexNumber++;
                parseLine(type, url, line, indexNumber, classNames);
            }
        } catch (Exception x) {
            logger.error(type.getName() + ":" + "Error reading spi configuration file", x);
        } finally {
            try {
                if (reader != null) {
                    reader.close();
                }
                if (inputStream != null) {
                    inputStream.close();
                }
            } catch (IOException y) {
                logger.error(type.getName() + ":" + "Error closing spi configuration file", y);
            }
        }
    }

    private void parseLine(Class<T> type, URL url, String line, int lineNumber, List<String> names) throws IOException,
            ServiceConfigurationError {
        int ci = line.indexOf('#');  //可以使用#在扩展文件后写一些说明
        if (ci >= 0) {
            line = line.substring(0, ci);
        }
        line = line.trim();
        if (line.length() <= 0) {
            return;
        }
        if ((line.indexOf(' ') >= 0) || (line.indexOf('	') >= 0)) {
            throw new RuntimeException(type.getName() + ": " + "Illegal spi configuration-file syntax");
        }
        int cp = line.codePointAt(0);
        if (!Character.isJavaIdentifierStart(cp)) {
            throw new RuntimeException(type.getName() + ": " + url + ": " + line + ": " + "Illegal spi provider-class name: " + line);
        }
        for (int i = Character.charCount(cp); i < line.length(); i += Character.charCount(cp)) {
            cp = line.codePointAt(i);
            if (!Character.isJavaIdentifierPart(cp) && (cp != '.')) {
                throw new RuntimeException(type.getName() + ": " + url + ": " + line + ": " + "Illegal spi provider-class name: " + line);
            }
        }
        if (!names.contains(line)) {
            names.add(line);
        }
    }
}

主要的功能点和细节都有注释,就不多说了。接下来看一些应用吧。

注解的使用:

@Spi(scope = Scope.SINGLETON)
public interface HealthPrinter {
    void print(Set<HealthStats> healthStats, String timestamp);
}
@SpiMeta(name = "jedisClusterHealthPrinter")
public class JedisClusterHealthPrinter extends AbstractHealthPrinter {

扩展文件的使用:

                   

总结:SPI只是一种思想,可以根据实际需要定制化实现。

原文地址:https://www.cnblogs.com/jing-yi/p/14379597.html