Transport Layer Security is something I've been working with since the very first publication of Cross Enterprise Document Sharing (XDS). For that very first Connectathon, we had to demonstrate that we could secure the XDS communications channels, and I struggled with that part. So much so that I wrote what might be considered my first blog post, the ATNA FAQ.
Since then, I've become something of an expert in implementing, diagnosing, and resolving issues with TLS implementations in Java. One of my most recent challenges involved diagnosing problems with a protocol where the usual methods in the ATNA FAQ don't work, because the underlying implementation is built on Bouncy Castle FIPS, which doesn't respond to the usual Java -Djavax.net.debug options I first wrote about 20 years ago. There's a good reason for that, but it makes debugging protocol implementations a bit of a beast when you cannot see what is happening. Sure, you can Wireshark it, but that doesn't always work when you aren't in control of the responding service (as I wasn't), since you may not be able to get the keys needed to decrypt the traffic.
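If you aren't sure which JSSE implementation your process is actually using, checking the provider behind the default SSLContext tells you whether -Djavax.net.debug (a SunJSSE facility) is likely to help at all. This is a minimal sketch; the class and method names are mine, and it assumes the provider you care about is the one registered for the default SSLContext. A stock JDK reports "SunJSSE", while Bouncy Castle's JSSE provider registers itself as "BCJSSE".

```java
import java.security.NoSuchAlgorithmException;
import javax.net.ssl.SSLContext;

class JsseProviderCheck {
    // Returns the name of the provider that supplies the default SSLContext.
    static String defaultJsseProvider() {
        try {
            return SSLContext.getDefault().getProvider().getName();
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("No default SSLContext available", e);
        }
    }

    public static void main(String[] args) {
        String name = defaultJsseProvider();
        System.out.println("Default JSSE provider: " + name);
        if (!"SunJSSE".equals(name)) {
            // javax.net.debug is read by SunJSSE; other providers may ignore it
            System.out.println("-Djavax.net.debug may have no effect with this provider");
        }
    }
}
```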
The solution is rather ugly, and it involves interceptors. What I wound up doing was getting hold of the client socket factory and wrapping it with my own implementation, which intercepts the socket creation calls. Each socket it creates is wrapped in an interceptor that catches the calls to get the underlying input and output streams, and each of those streams is in turn wrapped in an interceptor that copies the traffic onto my console. This is brutally tedious code, because it involves overriding EVERY implemented public method of these classes, calling the original intercepted object's method and returning the result (or slightly modifying it in a very few cases). Fortunately, modern IDEs will write MOST of this code for you, and search and replace will finish it (except for the special cases). I won't reproduce every line here, but I'll explain the technique.
First, understand that I'm dealing with client calls to a server, and these are built on HttpsURLConnection. Those of you who've been programming in Java may already know that this is not an interface but an abstract class, and the finicky details are found underneath it in sun.net.www.protocol.https.HttpsURLConnectionImpl (BTW: Good luck finding the source for that).
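A quick way to see that split for yourself: HttpsURLConnection is abstract, and openConnection() hands back a concrete subclass. The sketch below (class and helper names are hypothetical) prints both facts without touching the network, since openConnection() only creates the connection object and does not connect. On a stock OpenJDK the concrete class should be the internal HttpsURLConnectionImpl.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.reflect.Modifier;
import java.net.URI;
import java.net.URLConnection;
import javax.net.ssl.HttpsURLConnection;

class WhichConnectionImpl {
    // Creates (but does not connect) a connection object for the given https URL.
    static URLConnection open(String httpsUrl) {
        try {
            return URI.create(httpsUrl).toURL().openConnection();
        } catch (IOException e) { // also covers MalformedURLException
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println("HttpsURLConnection is abstract: "
                + Modifier.isAbstract(HttpsURLConnection.class.getModifiers()));
        URLConnection conn = open("https://example.com/");
        // The runtime class is the concrete implementation behind the abstract type
        System.out.println("Concrete class: " + conn.getClass().getName());
    }
}
```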
Once you have your URL connection, you need to wrap the SSLSocketFactory it is using with a new socket factory that intercepts calls to the existing one. Do this before the connection is actually made, because that is when the factory gets used.
HttpsURLConnection conx = (HttpsURLConnection) url.openConnection();
if (debug) {
    conx.setSSLSocketFactory(new MySSLSocketFactory(conx.getSSLSocketFactory()));
}
Next you are going to create this new SSLSocketFactory:
import javax.net.ssl.SSLSocketFactory;
public class MySSLSocketFactory extends SSLSocketFactory {
    private final SSLSocketFactory base;
    public MySSLSocketFactory(SSLSocketFactory base) {
        this.base = base;
    }
    ...
}
Inside this class, override every public method of SSLSocketFactory (Eclipse's Source > Override/Implement Methods... menu makes this easy).
Each override follows one of two patterns:
@Override
public void methodX(Type0 arg0, Type1 arg1) {
    base.methodX(arg0, arg1);
}
Or for methods returning a value:
@Override
public TypeR methodX(Type0 arg0, Type1 arg1) {
    return base.methodX(arg0, arg1);
}
For the createSocket methods, vary this slightly so that the socket that comes back is wrapped:
@Override
public Socket createSocket(Type0 arg0, Type1 arg1) throws IOException {
    return new MySSLSocketWrapper((SSLSocket) base.createSocket(arg0, arg1));
}
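Putting the two patterns together, the factory has only a handful of methods to cover: the abstract ones plus the no-argument createSocket(), which the JDK's HTTPS client may call to get an unconnected socket first. The sketch below is self-contained under one assumption: instead of calling MySSLSocketWrapper directly, it takes a hypothetical `wrap` function, so it compiles without the socket wrapper. The class name DelegatingSSLSocketFactory is also mine.

```java
import java.io.IOException;
import java.net.InetAddress;
import java.net.Socket;
import java.util.function.UnaryOperator;
import javax.net.ssl.SSLSocketFactory;

// Delegates every call to `base`; every socket it hands out first passes through `wrap`.
class DelegatingSSLSocketFactory extends SSLSocketFactory {
    private final SSLSocketFactory base;
    private final UnaryOperator<Socket> wrap; // hypothetical hook standing in for MySSLSocketWrapper

    DelegatingSSLSocketFactory(SSLSocketFactory base, UnaryOperator<Socket> wrap) {
        this.base = base;
        this.wrap = wrap;
    }

    // Plain pass-through methods
    @Override
    public String[] getDefaultCipherSuites() { return base.getDefaultCipherSuites(); }

    @Override
    public String[] getSupportedCipherSuites() { return base.getSupportedCipherSuites(); }

    // createSocket variants: delegate, then wrap. Exceptions from base pass through unchanged.
    @Override
    public Socket createSocket() throws IOException {
        return wrap.apply(base.createSocket());
    }

    @Override
    public Socket createSocket(Socket s, String host, int port, boolean autoClose) throws IOException {
        return wrap.apply(base.createSocket(s, host, port, autoClose));
    }

    @Override
    public Socket createSocket(String host, int port) throws IOException {
        return wrap.apply(base.createSocket(host, port));
    }

    @Override
    public Socket createSocket(String host, int port, InetAddress localHost, int localPort)
            throws IOException {
        return wrap.apply(base.createSocket(host, port, localHost, localPort));
    }

    @Override
    public Socket createSocket(InetAddress host, int port) throws IOException {
        return wrap.apply(base.createSocket(host, port));
    }

    @Override
    public Socket createSocket(InetAddress address, int port, InetAddress localAddress, int localPort)
            throws IOException {
        return wrap.apply(base.createSocket(address, port, localAddress, localPort));
    }
}
```

With the socket wrapper in place, the hook would be something like `s -> new MySSLSocketWrapper((SSLSocket) s)`.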
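The socket wrapper extends SSLSocket and hands everything to the wrapped socket, except that it wraps the streams it returns:
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import javax.net.ssl.SSLSocket;
public class MySSLSocketWrapper extends SSLSocket {
    private final SSLSocket base;
    public MySSLSocketWrapper(SSLSocket base) {
        this.base = base;
    }
    @Override
    public InputStream getInputStream() throws IOException {
        return new MyInterceptingInputStream(base.getInputStream());
    }
    @Override
    public OutputStream getOutputStream() throws IOException {
        return new MyInterceptingOutputStream(base.getOutputStream());
    }
    // ... Add other overrides here, each just calling the base method of the same name.
    // The abstract SSLSocket methods must be implemented anyway; also delegate inherited
    // Socket methods such as connect(), close(), isConnected() and setSoTimeout(), or they
    // will act on this wrapper's own unconnected socket instead of base.
}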
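Finally, the intercepting input stream copies everything it reads onto the console in blue (using ANSI escape codes), so the incoming traffic stands out. The output side follows the same pattern for write():
import java.io.*;
public class MyInterceptingInputStream extends FilterInputStream {
    // ANSI escape sequences: console text in blue, then back to normal
    private static final byte[] BLUE = { '\033', '[', '3', '4', 'm' };
    private static final byte[] NORMAL = { '\033', '[', '0', 'm' };
    private final OutputStream out = System.out; // where the copy goes (assumed: the console)
    public MyInterceptingInputStream(InputStream in) {
        super(in); // FilterInputStream keeps the wrapped stream in the field "in"
    }
    @Override
    public int read(byte[] b, int off, int len) throws IOException {
        int val = in.read(b, off, len);
        if (val > 0) {
            try {
                out.write(BLUE);
                out.write(b, off, val);
                out.write(NORMAL);
            } catch (IOException ex) {
                // Swallow it: console trouble must not break the real connection.
            }
        }
        return val;
    }
    // Override read() too, or single-byte reads would bypass the copy
    @Override
    public int read() throws IOException {
        byte[] one = new byte[1];
        int val = read(one, 0, 1);
        return val == -1 ? -1 : one[0] & 0xFF;
    }
}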
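The socket wrapper also refers to a MyInterceptingOutputStream, which isn't shown. One possible version, assuming the same kind of console echo as the input side, is sketched below. Red for outbound traffic and the two-argument constructor (which lets you point the copy somewhere other than System.out, e.g. for testing) are my own choices, not the original code.

```java
import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.OutputStream;

// Echoes every byte written to the socket onto a console stream, wrapped in ANSI color codes.
class MyInterceptingOutputStream extends FilterOutputStream {
    private static final byte[] RED = { '\033', '[', '3', '1', 'm' }; // assumption: red = outbound
    private static final byte[] NORMAL = { '\033', '[', '0', 'm' };
    private final OutputStream console;

    MyInterceptingOutputStream(OutputStream target) {
        this(target, System.out); // default: copy to the console
    }

    MyInterceptingOutputStream(OutputStream target, OutputStream console) {
        super(target); // FilterOutputStream keeps the real target in the field "out"
        this.console = console;
    }

    @Override
    public void write(int b) throws IOException {
        out.write(b); // the real write happens first, and its errors propagate
        echo(new byte[] { (byte) b }, 0, 1);
    }

    @Override
    public void write(byte[] b, int off, int len) throws IOException {
        // FilterOutputStream's default would write one byte at a time; pass the block straight through
        out.write(b, off, len);
        echo(b, off, len);
    }

    private void echo(byte[] b, int off, int len) {
        if (len <= 0) {
            return;
        }
        try {
            console.write(RED);
            console.write(b, off, len);
            console.write(NORMAL);
            console.flush();
        } catch (IOException ex) {
            // Swallow it: console trouble must not break the real connection.
        }
    }
}
```

Because the real write happens before the echo, a failing console can never cost you data on the wire; the worst case is a missing line in your debug output.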