Mybatis之拦截器实现及原理

为什么要有插件

可以在映射语句执行前后加一些自定义的操作,比如缓存、分页等

可以拦截哪些方法

Executor:update、query、flushStatements、commit、rollback、getTransaction、close、isClosed。
实现类:SimpleExecutor/BatchExecutor/ReuseExecutor/CachingExecutor

ParameterHandler:getParameterObject、setParameters。
实现类:DefaultParameterHandler

ResultSetHandler:handleResultSets、handleOutputParameters。
实现类:DefaultResultSetHandler

StatementHandler:prepare、parameterize、batch、update、query。
实现类:CallableStatementHandler/PreparedStatementHandler/SimpleStatementHandler/RoutingStatementHandler

如何自定义插件

1、只需实现Interceptor接口,并指定要拦截的方法签名
2、还需要在配置文件中配置你编写的插件

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
@Intercepts({
@Signature(
type=Executor.class,method="update",args={ MappedStatement.class,Object.class })
})
public class ExamplePlugin implements Interceptor {
public Object intercept(Invocation invocation) throws Throwable {
//自定义实现
return invocation.proceed();
}
public Object plugin(Object target){
return Plugin.wrap(target,this)
}
public void setProperties(Properties properties){
//传入配置项
String size = properties.getProperty("size");
}
}
<!-- mybatis-config.xml -->
<plugins>
<plugin interceptor="org.mybatis.example.ExamplePlugin">
<!-- 这里的配置项就传入setProperties方法中 -->
<property name="size" value="100">
</plugin>
</plugins>

拦截器实现原理

Mybatis仅可以编写针对ParameterHandler、ResultSetHandler、StatementHandler、Executor这4种接口的插件,Mybatis使用JDK的动态代理,为需要拦截的接口生成代理对象以实现接口方法拦截功能,每当执行这4种接口对象的方法时,就会进入拦截方法,具体就是InvocationHandler的invoke()方法,当然,只会拦截那些你指定需要拦截的方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
//拦截器接口,供外部实现,实现该接口就定义了一个插件
public interface Interceptor {
//拦截方法,可以将自定义逻辑写在该方法中
Object intercept(Invocation invocation) throws Throwable;
//包装成插件,一般Plugin.wrap(target,this)就行了
Object plugin(Object target);
//传入自定义配置参数
void setProperties(Properties properties);
}
拦截器上定义的注解
@Intercepts:拦截器注解,包括一个或多个@Signature,拦截的目标类信息
@Signature:拦截的目标类信息,包括type、method、args,一个@Intercepts中可包含多个@Signature

public class Invocation {
private Object target;//目标对象
private Method method;//调用方法
private Object[] args;//方法形参列表
//省略get和set方法
//执行调用,基于动态代理,在Interceptor的intercept方法中一定要调用该方法
public Object proceed() throws InvocationTargetException, IllegalAccessException {
return method.invoke(target, args);
}
}
//动态代理实现
public class Plugin implements InvocationHandler {
private Object target;
private Interceptor interceptor;//拦截器
private Map<Class<?>, Set<Method>> signatureMap;//拦截目标类的目标方法

private Plugin(Object target, Interceptor interceptor, Map<Class<?>, Set<Method>> signatureMap) {
this.target = target;
this.interceptor = interceptor;
this.signatureMap = signatureMap;
}
//包装目标实例
public static Object wrap(Object target, Interceptor interceptor) {
Map<Class<?>, Set<Method>> signatureMap = getSignatureMap(interceptor);
Class<?> type = target.getClass();
//目标类所有接口是否有signatureMap中定义的Class
Class<?>[] interfaces = getAllInterfaces(type, signatureMap);
//如果拦截器中有定义拦截目标类中的方法时,就返回代理实例
if (interfaces.length > 0) {
return Proxy.newProxyInstance(
type.getClassLoader(),
interfaces,
new Plugin(target, interceptor, signatureMap));
}
//没有就返回目标实例
return target;
}

@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
try {
Set<Method> methods = signatureMap.get(method.getDeclaringClass());
//该方法需要拦截
if (methods != null && methods.contains(method)) {
return interceptor.intercept(new Invocation(target, method, args));
}
return method.invoke(target, args);
} catch (Exception e) {
throw ExceptionUtil.unwrapThrowable(e);
}
}
//获取拦截器上的SignatureMap
private static Map<Class<?>, Set<Method>> getSignatureMap(Interceptor interceptor) {
Intercepts interceptsAnnotation = interceptor.getClass().getAnnotation(Intercepts.class);
if (interceptsAnnotation == null) {
throw new PluginException("No @Intercepts annotation was found in interceptor " + interceptor.getClass().getName());
}
Signature[] sigs = interceptsAnnotation.value();
Map<Class<?>, Set<Method>> signatureMap = new HashMap<Class<?>, Set<Method>>();
for (Signature sig : sigs) {
Set<Method> methods = signatureMap.get(sig.type());//重复定义的只生效一个
if (methods == null) {
methods = new HashSet<Method>();
signatureMap.put(sig.type(), methods);
}
try {
//获取目标类中的指定方法
Method method = sig.type().getMethod(sig.method(), sig.args());
methods.add(method);
} catch (NoSuchMethodException e) {
throw new PluginException("Could not find method on " + sig.type() + " named " + sig.method() + ". Cause: " + e, e);
}
}
return signatureMap;
}

private static Class<?>[] getAllInterfaces(Class<?> type, Map<Class<?>, Set<Method>> signatureMap) {
Set<Class<?>> interfaces = new HashSet<Class<?>>();
while (type != null) {//获取type上的所有接口
for (Class<?> c : type.getInterfaces()) {
if (signatureMap.containsKey(c)) {//这里不判断Method,只判断Class<?>
interfaces.add(c);
}
}
type = type.getSuperclass();
}
return interfaces.toArray(new Class<?>[interfaces.size()]);
}
}

在配置文件中定义的过滤器,都保存在Configuration类的interceptorChain中,这个类保存了mybatis的所有配置,interceptorChain类中保存中所有Interceptor集合组成的拦截器链,这个链是如何添加进去的呢?请看源码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
 //XMLConfigBuilder类中解析mybatis-config.xml 核心方法parseConfiguration(XNode root)
pluginElement(root.evalNode("plugins"));//插件配置项

private void pluginElement(XNode parent) throws Exception {
if (parent != null) {
//遍历 plugins的子节点plugin
for (XNode child : parent.getChildren()) {
String interceptor = child.getStringAttribute("interceptor");//获取interceptor属性值
Properties properties = child.getChildrenAsProperties();//获取plugin属性值
//创建拦截器实例,这里interceptor值也可以是typeAlias注册的简名
Interceptor interceptorInstance = (Interceptor) resolveClass(interceptor).newInstance();
//设置属性项
interceptorInstance.setProperties(properties);
//添加到interceptorChain中
configuration.addInterceptor(interceptorInstance);
}
}
}
//Configuration类,添加拦截器
public void addInterceptor(Interceptor interceptor) {
interceptorChain.addInterceptor(interceptor);
}

拦截的哪些接口

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
//SQL语句处理器
public interface StatementHandler {
//预备工作
Statement prepare(Connection connection, Integer transactionTimeout) throws SQLException;
//参数处理
void parameterize(Statement statement) throws SQLException;
//批量处理
void batch(Statement statement) throws SQLException;
//更新处理
int update(Statement statement) throws SQLException;
//查询处理
<E> List<E> query(Statement statement, ResultHandler resultHandler) throws SQLException;
}
//返回集处理器
public interface ResultSetHandler {
//处理返回结果
<E> List<E> handleResultSets(Statement stmt) throws SQLException;
//处理输出参数
void handleOutputParameters(CallableStatement cs) throws SQLException;
}
//参数处理器
public interface ParameterHandler {

Object getParameterObject();

void setParameters(PreparedStatement ps) throws SQLException;
}

如何拦截这些接口

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
//创建相应Handler时会将所有拦截器通过动态代理方式返回代理Handler
public class Configuration {

//创建ParameterHandler(参数处理器)
public ParameterHandler newParameterHandler(MappedStatement mappedStatement,
Object parameterObject, BoundSql boundSql) {
// 根据指定Lang(默认RawLanguageDriver),创建ParameterHandler,将实际参数传递给JDBC语句
ParameterHandler parameterHandler = mappedStatement.getLang().createParameterHandler(
mappedStatement, parameterObject, boundSql);
//返回代理实例
parameterHandler = (ParameterHandler) interceptorChain.pluginAll(parameterHandler);
return parameterHandler;
}

//创建ResultSetHandler(结果处理器)
public ResultSetHandler newResultSetHandler(Executor executor, MappedStatement mappedStatement,
RowBounds rowBounds,ParameterHandler parameterHandler,
ResultHandler resultHandler,BoundSql boundSql) {
//默认使用DefaultResultSetHandler创建ResultSetHandler实例
ResultSetHandler resultSetHandler = new DefaultResultSetHandler(executor, mappedStatement,
parameterHandler, resultHandler, boundSql, rowBounds);
//返回代理实例
resultSetHandler = (ResultSetHandler) interceptorChain.pluginAll(resultSetHandler);
return resultSetHandler;
}

//创建StatementHandler(SQL语句处理器)
public StatementHandler newStatementHandler(Executor executor, MappedStatement mappedStatement,
Object parameterObject, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
//默认使用RoutingStatementHandler(路由作用)
//创建指定StatementHandler实例(默认SimpleStatementHandler)
StatementHandler statementHandler = new RoutingStatementHandler(executor, mappedStatement,
parameterObject, rowBounds, resultHandler, boundSql);
//返回代理实例
statementHandler = (StatementHandler) interceptorChain.pluginAll(statementHandler);
return statementHandler;
}

//创建Executor(执行器)
public Executor newExecutor(Transaction transaction, ExecutorType executorType) {
//获取executorType,默认是SIMPLE
executorType = executorType == null ? defaultExecutorType : executorType;
//这一行感觉有点多余啊
executorType = executorType == null ? ExecutorType.SIMPLE : executorType;
Executor executor;
if (ExecutorType.BATCH == executorType) { //批量执行
executor = new BatchExecutor(this, transaction);
} else if (ExecutorType.REUSE == executorType) { //重用
executor = new ReuseExecutor(this, transaction);
} else {
executor = new SimpleExecutor(this, transaction);//通用
}
if (cacheEnabled) { //开启缓存
executor = new CachingExecutor(executor);
}
//返回代理实例
executor = (Executor) interceptorChain.pluginAll(executor);
return executor;
}
}
//执行器
public interface Executor {
//更新
int update(MappedStatement ms, Object parameter) throws SQLException;
//查询(先查缓存)
<E> List<E> query(MappedStatement ms, Object parameter, RowBounds rowBounds,
ResultHandler resultHandler, CacheKey cacheKey, BoundSql boundSql) throws SQLException;
//查询
<E> List<E> query(MappedStatement ms, Object parameter,
RowBounds rowBounds, ResultHandler resultHandler) throws SQLException;
//查询游标
<E> Cursor<E> queryCursor(MappedStatement ms, Object parameter, RowBounds rowBounds)
throws SQLException;
//刷新Statement
List<BatchResult> flushStatements() throws SQLException;
//提交事务
void commit(boolean required) throws SQLException;
//回滚事务
void rollback(boolean required) throws SQLException;
//创建缓存key
CacheKey createCacheKey(MappedStatement ms, Object parameterObject,RowBounds rowBounds,
BoundSql boundSql);
//是否存在key
boolean isCached(MappedStatement ms, CacheKey key);
//清除本地缓存
void clearLocalCache();
//延迟加载
void deferLoad(MappedStatement ms, MetaObject resultObject, String property, CacheKey key,
Class<?> targetType);
//获取事务
Transaction getTransaction();
//关闭连接
void close(boolean forceRollback);
//是否关闭
boolean isClosed();
//设置Executor
void setExecutorWrapper(Executor executor);
}

总结

当然具体实现肯定不止这么多代码,如果需要了解,需要自行看源码。

  • 拦截器实现
    Interceptor接口供插件实现,@Intercepts注解在插件实现上,表示这是一个插件类并配置将要拦截哪些方法,@Signature定义将要拦截的方法信息,如名称/类型/形参列表,Plugin类实现了InvocationHandler接口,是动态代理的具体实现,Invocation类包装了拦截的目标实例,InterceptorChain保存所有拦截器。
  • 如何实现拦截
    创建目标实例,比如A a = new A();
    Interceptor interceptor = new LogInterceptor();//如果拦截a中的save方法
    将A b = (A)interceptor.plugin(a);这里b就是a的代理实例,在调用a中的save方法时,实际将调用interceptor的intercept方法,在该方法中一定要调用Invocation的proceed方法并将返回值返回。

参考资料