Recent CXF releases ship a websocket transport. It is great, but it doesn't let you control the transport, and it can pull in dependencies you may not want, such as Atmosphere.

However, if you want to implement it yourself, there is a workaround based on a ServletContextListener.

Our assumptions for this blog post are:

  • The CXF bus is in the CDI context (so it is injectable); if not, replace the injection with another lookup implementation
  • There is a single CXF server; if not, replace iterator().next() with a loop and potentially keep only the JAX-RS servers
  • There is a single JAX-RS Application subclass in the CDI context; if not, change the application prefix lookup logic (it can be done using CXF providers)
  • The Application path starts with /api/vX, where X is the version, and we want to extract the version to keep it in our websocket redirection

None of these assumptions is a strong constraint for a real application: the code could be made more generic, but since CXF will probably enhance its own implementation, let's keep this tip simple.

Here is what it can look like in Java:

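import java.util.Spliterator;
import java.util.Spliterators;
import java.util.stream.StreamSupport;

import javax.enterprise.context.Dependent;
import javax.enterprise.inject.Instance;
import javax.inject.Inject;
import javax.servlet.ServletContextEvent;
import javax.servlet.ServletContextListener;
import javax.servlet.annotation.WebListener;
import javax.ws.rs.ApplicationPath;
import javax.ws.rs.core.Application;

import lombok.extern.slf4j.Slf4j;
import org.apache.cxf.Bus;
import org.apache.cxf.endpoint.ServerRegistry;
import org.apache.cxf.jaxrs.JAXRSServiceFactoryBean;

@Slf4j
@Dependent
@WebListener
public class WebSocketBroadcastSetup implements ServletContextListener {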
    @Inject
    private Bus bus;

    @Inject
    private Instance<Application> applications;

    @Override
    public void contextInitialized(ServletContextEvent sce) {
        final JAXRSServiceFactoryBean factory = JAXRSServiceFactoryBean.class.cast(bus.getExtension(ServerRegistry.class)
                .getServers().iterator().next() // you could use a loop
                .getEndpoint()
                .get(JAXRSServiceFactoryBean.class.getName()));

        final String appBase = StreamSupport
                .stream(Spliterators.spliteratorUnknownSize(applications.iterator(), Spliterator.IMMUTABLE), false)
                .filter(a -> a.getClass().isAnnotationPresent(ApplicationPath.class))
                .map(a -> a.getClass().getAnnotation(ApplicationPath.class))
                .map(ApplicationPath::value).findFirst()
                .map(s -> !s.startsWith("/") ? "/" + s : s).orElse("/api/v1");
        final String version = appBase.replaceFirst("/api", "");

        // coming in next snippets
    }
}
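To make the prefix logic easier to check in isolation, here is a minimal standalone sketch of the same normalization and version extraction. ApplicationPaths and its methods are hypothetical names introduced for illustration; the Optional stands for the result of the @ApplicationPath lookup.

```java
import java.util.Optional;

// Hypothetical helper mirroring the prefix logic of contextInitialized().
final class ApplicationPaths {
    private ApplicationPaths() {
        // static helpers only
    }

    // Ensure a leading "/" and fall back to /api/v1 when no @ApplicationPath was found.
    static String appBase(final Optional<String> applicationPath) {
        return applicationPath
                .map(s -> !s.startsWith("/") ? "/" + s : s)
                .orElse("/api/v1");
    }

    // Strip the "/api" prefix: "/api/v2" becomes "/v2".
    static String version(final String appBase) {
        return appBase.replaceFirst("/api", "");
    }
}
```

With @ApplicationPath("api/v2") this gives an appBase of /api/v2 and a version of /v2. Since replaceFirst matches anywhere in the string, an anchored regex such as ^/api would be a stricter choice if your paths can contain /api elsewhere.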

Now we need to iterate over all the endpoints, which can be done with this code:

Stream<OperationResourceInfo> oris = factory.getClassResourceInfo().stream()
    .flatMap(cri -> cri.getMethodDispatcher().getOperationResourceInfos().stream());

Next, we need to decide how to expose these endpoints over the websocket protocol. There are multiple options:

  • Deploy one websocket endpoint per JAX-RS endpoint (/api/vX/foo becoming /websocket/vX/foo)
  • Deploy a single websocket endpoint and route messages to the JAX-RS runtime based on another criterion (a header, for instance)
  • Do both at the same time

The second option is the most interesting one, since a single websocket connection gives you access to all the endpoints, but the third one doesn't cost much more and simplifies API discovery, so we'll go this way.

To implement that, we need the websocket "container" API, which can be retrieved from the servlet context:

ServerContainer container = ServerContainer.class.cast(
    sce.getServletContext().getAttribute(ServerContainer.class.getName()));

Then, for the actual implementation, we will fake an HTTP request for each websocket message and redirect the response output to the websocket session.

For the protocol, we will use a format close to STOMP (but not exactly STOMP):

  • a request starts with SEND, followed by one line per header (including the virtual destination header, which holds the actual endpoint path), an empty line, and finally the payload terminated by the two characters ^@ (real STOMP uses a NUL byte, which its specification writes as ^@)
  • a response uses the same format as a request, except that it starts with MESSAGE instead of SEND
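As an illustration of this request format, a minimal client-side encoder could look like the following sketch. FrameEncoder and send are hypothetical names, and the payload is assumed to be text that never contains ^@.

```java
import java.nio.charset.StandardCharsets;
import java.util.Map;

// Hypothetical client-side helper building a request in the STOMP-like format.
final class FrameEncoder {
    static final String EOM = "^@";

    private FrameEncoder() {
        // static helpers only
    }

    static byte[] send(final String destination, final Map<String, String> headers, final String payload) {
        final StringBuilder frame = new StringBuilder("SEND\r\n");
        // virtual header holding the JAX-RS path, optionally with a query string
        frame.append("destination: ").append(destination).append("\r\n");
        headers.forEach((name, value) -> frame.append(name).append(": ").append(value).append("\r\n"));
        frame.append("\r\n"); // empty line: what follows is the payload
        frame.append(payload).append(EOM);
        return frame.toString().getBytes(StandardCharsets.UTF_8);
    }
}
```

The empty line matters: the endpoint stops parsing headers at the first empty line and hands the remaining bytes, up to ^@, to the JAX-RS runtime as the request body.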

The implementation relies on the CXF servlet integration, namely ServletController, whose invoke method takes a request and a response. Therefore we "just" need to convert each JAX-RS endpoint into a websocket endpoint invoking this controller.

Here is what the websocket endpoint can look like:

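@Slf4j // generates the "log" field used in the callbacks
@Data
public class JAXRSEndpoint extends Endpoint {
    // end-of-message marker of our STOMP-like protocol
    private static final String EOM = "^@";

    private final String appBase;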
    private final ServletController controller;
    private final ServletContext context;
    private final String method;
    private final String defaultUri;
    private final Map<String, List<String>> baseHeaders;

    @Override
    public void onOpen(final Session session, final EndpointConfig endpointConfig) {
        log.debug("Opened session {}", session.getId());
        session.addMessageHandler(InputStream.class, message -> {
            final Map<String, List<String>> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
            headers.putAll(baseHeaders);

            final StringBuilder buffer = new StringBuilder(128);
            try { // read headers from the message
                if (!"SEND".equalsIgnoreCase(readLine(buffer, message))) {
                    throw new IllegalArgumentException("not a message");
                }

                String line;
                int del;
                while ((line = readLine(buffer, message)) != null) {
                    final boolean done = line.endsWith(EOM);
                    if (done) {
                        line = line.substring(0, line.length() - EOM.length());
                    }
                    if (!line.isEmpty()) {
                        del = line.indexOf(':');
                        if (del < 0) {
                            headers.put(line.trim(), emptyList());
                        } else {
                            headers.put(line.substring(0, del).trim(), singletonList(line.substring(del + 1).trim()));
                        }
                    }
                    if (done) {
                        break;
                    }
                }
            } catch (final IOException ioe) {
                throw new IllegalStateException(ioe);
            }

            final List<String> uris = headers.get("destination");
            final String uri;
            if (uris == null || uris.isEmpty()) {
                uri = defaultUri;
            } else {
                uri = uris.iterator().next();
            }

            final String queryString;
            final String path;
            final int query = uri.indexOf('?');
            if (query > 0) {
                queryString = uri.substring(query + 1);
                path = uri.substring(0, query);
            } else {
                queryString = null;
                path = uri;
            }

            try {
                final WebSocketRequest request = new WebSocketRequest(method, headers, path, appBase + path, appBase,
                        queryString, 8080, context, new WebSocketInputStream(message), session);
                controller.invoke(request, new WebSocketResponse(session));
            } catch (final ServletException e) {
                throw new IllegalArgumentException(e);
            }
        });
    }

    @Override
    public void onClose(final Session session, final CloseReason closeReason) {
        log.debug("Closed session {}", session.getId());
    }

    @Override
    public void onError(final Session session, final Throwable throwable) {
        log.warn("Error for session {}", session.getId(), throwable);
    }

    private static String readLine(final StringBuilder buffer, final InputStream in) throws IOException {
        int c;
        while ((c = in.read()) != -1) {
            if (c == '\n') {
                break;
            } else if (c != '\r') {
                buffer.append((char) c);
            }
        }

        if (buffer.length() == 0) {
            return null;
        }
        final String string = buffer.toString();
        buffer.setLength(0);
        return string;
    }
}

As you can see, the message handler mainly parses the message according to its format and forwards it to the servlet controller. WebSocketRequest is a request mock without much logic, so I will not detail it here, but notice that the input stream is passed to the request. This is important because that stream reads the websocket message only up to the end-of-message marker (^@):

public class WebSocketInputStream extends ServletInputStream {
    private final InputStream delegate;
    private boolean finished;
    private int previous = Integer.MAX_VALUE;

    public WebSocketInputStream(final InputStream delegate) {
        this.delegate = delegate;
    }

    @Override
    public boolean isFinished() {
        return finished;
    }

    @Override
    public boolean isReady() {
        return true;
    }

    @Override
    public void setReadListener(final ReadListener readListener) {
        // no-op
    }

    @Override
    public int read() throws IOException {
        if (finished) {
            return -1;
        }
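        if (previous != Integer.MAX_VALUE) {
            // return the byte read ahead while looking for "^@", not the sentinel
            final int pending = previous;
            previous = Integer.MAX_VALUE;
            if (pending < 0) {
                finished = true;
            }
            return pending;
        }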
        final int read = delegate.read();
        if (read == '^') {
            previous = delegate.read();
            if (previous == '@') {
                finished = true;
                return -1;
            }
        }
        if (read < 0) {
            finished = true;
        }
        return read;
    }
}
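Since the one-byte look-ahead is the subtle part, here is a standalone sketch of the same technique on a plain InputStream, which can be tried against a byte array. EomReader and readUntilEom are hypothetical names; unlike WebSocketInputStream, this version collects the whole payload at once instead of streaming it.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;

// Hypothetical standalone version of the "read until ^@" logic of WebSocketInputStream.
final class EomReader {
    private EomReader() {
        // static helpers only
    }

    static byte[] readUntilEom(final InputStream in) {
        final ByteArrayOutputStream out = new ByteArrayOutputStream();
        try {
            int current = in.read();
            while (current != -1) {
                if (current == '^') {
                    final int next = in.read(); // one byte of look-ahead
                    if (next == '@') {
                        break; // end of message, the marker itself is dropped
                    }
                    out.write(current); // a lone '^' belongs to the payload
                    current = next;     // re-examine the look-ahead byte
                    continue;
                }
                out.write(current);
                current = in.read();
            }
        } catch (final IOException e) {
            throw new UncheckedIOException(e);
        }
        return out.toByteArray();
    }
}
```

A lone ^ (as in a^b) stays in the payload, and ^^@ yields a single ^ because the second ^ is re-examined against the byte that follows it.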

The response is almost as trivial as the request (headers are a case-insensitive map of string keys to lists of strings, etc.), except for the output stream handling. Here is a simplified response implementation (the easy methods were removed for brevity):

public class WebSocketResponse implements HttpServletResponse {

    private final Session session;
    private String responseString = "OK";
    private int code = HttpServletResponse.SC_OK;
    private final Map<String, List<String>> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
    private transient PrintWriter writer;
    private transient ServletByteArrayOutputStream sosi;
    private boolean commited = false;
    private String encoding = "UTF-8";
    private Locale locale = Locale.getDefault();

    public WebSocketResponse(final Session session) {
        this.session = session;
    }

    @Override
    public void sendRedirect(final String path) throws IOException {
        if (commited) {
            throw new IllegalStateException("response already committed");
        }
        resetBuffer();

        try {
            setStatus(SC_FOUND);

            setHeader("Location", toEncoded(path));
        } catch (final IllegalArgumentException e) {
            setStatus(SC_NOT_FOUND);
        }
    }

    @Override
    public ServletOutputStream getOutputStream() {
        return sosi == null ? (sosi = createOutputStream()) : sosi;
    }

    @Override
    public void reset() {
        createOutputStream();
    }

    private ServletByteArrayOutputStream createOutputStream() {
        return sosi = new ServletByteArrayOutputStream(session, () -> {
            final StringBuilder top = new StringBuilder("MESSAGE\r\n");
            top.append("status: ").append(getStatus()).append("\r\n");
            headers.forEach(
                    (k, v) -> top.append(k).append(": ").append(v.stream().collect(Collectors.joining(","))).append("\r\n"));
            top.append("\r\n");// empty line, means the next bytes are the payload
            return top.toString();
        });
    }

    public void setCode(final int code) {
        this.code = code;
        commited = true;
    }

    @Override
    public void resetBuffer() {
        // sosi is created lazily and its buffer field is private: use its public reset()
        if (sosi != null) {
            sosi.reset();
        }
    }
}

The interesting part is that the output stream is where the headers accumulated in memory get serialized: they are sent right before the first payload bytes, with the little trick of adding the status as a header to simplify client-side handling. Note that real STOMP would use an ERROR frame instead of a MESSAGE one for failures.
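On the client side, a complete response can then be decoded by reading the status pseudo-header first. The following is a minimal sketch assuming the whole frame (up to ^@) has already been received and decoded as text; ResponseFrame and parse are hypothetical names.

```java
import java.util.Map;
import java.util.TreeMap;

// Hypothetical client-side decoder for a complete MESSAGE frame.
final class ResponseFrame {
    private static final String START = "MESSAGE\r\n";

    final int status;
    final Map<String, String> headers;
    final String body;

    private ResponseFrame(final int status, final Map<String, String> headers, final String body) {
        this.status = status;
        this.headers = headers;
        this.body = body;
    }

    static ResponseFrame parse(final String frame) {
        if (!frame.startsWith(START)) {
            throw new IllegalArgumentException("not a MESSAGE frame");
        }
        // the empty line ending the headers; search from the CRLF closing "MESSAGE"
        final int blank = frame.indexOf("\r\n\r\n", START.length() - 2);
        if (blank < 0) {
            throw new IllegalArgumentException("missing empty line after the headers");
        }
        final Map<String, String> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
        for (final String line : frame.substring(START.length(), blank + 2).split("\r\n")) {
            final int del = line.indexOf(':');
            if (del > 0) {
                headers.put(line.substring(0, del).trim(), line.substring(del + 1).trim());
            }
        }
        String body = frame.substring(blank + 4);
        if (body.endsWith("^@")) {
            body = body.substring(0, body.length() - 2); // drop the end-of-message marker
        }
        // "status" is the pseudo-header added by WebSocketResponse, -1 if absent
        final int status = Integer.parseInt(headers.getOrDefault("status", "-1"));
        return new ResponseFrame(status, headers, body);
    }
}
```

A status outside the 2xx range can then be handled the way a STOMP client would handle an ERROR frame.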

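ServletByteArrayOutputStream is what links the response to the websocket session. It buffers the data in memory and only sends it to the client on close(), or on flush() once at least 8 KB are buffered; the first time data is sent, it uses the hook implemented just before to send the headers. A single response is therefore sent as several binary websocket messages (at least one for the headers and one for the payload), so the client has to concatenate them until it reads ^@. Here is one implementation: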

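public class ServletByteArrayOutputStream extends ServletOutputStream {
    // end-of-message marker of our STOMP-like protocol (same value as in JAXRSEndpoint)
    private static final String EOM = "^@";
    private static final byte[] EOM_BYTES = EOM.getBytes(StandardCharsets.UTF_8);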
    private static final int BUFFER_SIZE = 1024 * 8;

    private final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
    private final Session session;
    private final Supplier<String> preWrite;
    private boolean closed;
    private boolean headerWritten;

    public ServletByteArrayOutputStream(final Session session, final Supplier<String> preWrite) {
        this.session = session;
        this.preWrite = preWrite;
    }

    @Override
    public boolean isReady() {
        return true;
    }

    @Override
    public void setWriteListener(final WriteListener listener) {
        // no-op
    }

    @Override
    public void write(final int b) throws IOException {
        outputStream.write(b);
    }

    @Override
    public void write(final byte[] b, final int off, final int len) {
        outputStream.write(b, off, len);
    }

    public void writeTo(final OutputStream out) throws IOException {
        outputStream.writeTo(out);
    }

    public void reset() {
        outputStream.reset();
    }

    @Override
    public void flush() throws IOException {
        if (!session.isOpen()) {
            return;
        }
        if (outputStream.size() >= BUFFER_SIZE) {
            doFlush();
        }
    }

    @Override
    public void close() throws IOException {
        if (closed) {
            return;
        }

        outputStream.write(EOM_BYTES);
        doFlush();
        closed = true;
    }

    private void doFlush() throws IOException {
        final RemoteEndpoint.Basic basicRemote = session.getBasicRemote();

        final byte[] array = outputStream.toByteArray();
        final boolean written = array.length > 0 || !headerWritten;

        if (!headerWritten) {
            final String headers = preWrite.get();
            basicRemote.sendBinary(ByteBuffer.wrap(headers.getBytes(StandardCharsets.UTF_8)));
            headerWritten = true;
        }

        if (array.length > 0) {
            outputStream.reset();
            basicRemote.sendBinary(ByteBuffer.wrap(array));
        }

        if (written && basicRemote.getBatchingAllowed()) {
            basicRemote.flushBatch();
        }
    }
}
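Since the headers and the payload arrive as separate binary messages (and larger payloads in several chunks), a client has to buffer them until the marker arrives. Here is a minimal, container-agnostic sketch; FrameAssembler is a hypothetical name, and it relies on the payload never containing ^@, which this protocol already assumes.

```java
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.util.Optional;

// Hypothetical client-side accumulator for responses split over several binary messages.
final class FrameAssembler {
    private final ByteArrayOutputStream pending = new ByteArrayOutputStream();

    // Appends one websocket message; returns the full frame once it ends with "^@".
    Optional<byte[]> onBinaryMessage(final ByteBuffer message) {
        while (message.hasRemaining()) {
            pending.write(message.get());
        }
        final byte[] bytes = pending.toByteArray();
        final int n = bytes.length;
        if (n >= 2 && bytes[n - 2] == '^' && bytes[n - 1] == '@') {
            pending.reset(); // ready for the next response
            return Optional.of(bytes);
        }
        return Optional.empty();
    }
}
```

You would call onBinaryMessage from your websocket client's binary message callback (for instance a javax.websocket @OnMessage method taking a ByteBuffer) and decode the frame once it is returned.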

This should be enough for all synchronous endpoints, but if you use asynchronous endpoints (@Suspended), you need to add a websocket continuation, since your servlet "mock" will likely not go as far as implementing the AsyncContext logic and its state machine.

To do that, we need our own destination, which simply wraps the servlet one. First, we implement a custom registry wrapping the servlet destinations:

public class WebSocketRegistry implements DestinationRegistry {
    private final DestinationRegistry delegate;
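    // package-private so WebSocketContinuation can read it; assumed to be set once the ServletController is created
    ServletController controller;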

    public WebSocketRegistry(final DestinationRegistry registry) {
        this.delegate = registry;
    }

    @Override
    public void addDestination(final AbstractHTTPDestination destination) {
        throw new UnsupportedOperationException();
    }

    @Override
    public void removeDestination(final String path) {
        throw new UnsupportedOperationException();
    }

    @Override
    public AbstractHTTPDestination getDestinationForPath(final String path) {
        return wrap(delegate.getDestinationForPath(path));
    }

    @Override
    public AbstractHTTPDestination getDestinationForPath(final String path, final boolean tryDecoding) {
        return wrap(delegate.getDestinationForPath(path, tryDecoding));
    }

    @Override
    public AbstractHTTPDestination checkRestfulRequest(final String address) {
        return wrap(delegate.checkRestfulRequest(address));
    }

    @Override
    public Collection<AbstractHTTPDestination> getDestinations() {
        return delegate.getDestinations();
    }

    @Override
    public AbstractDestination[] getSortedDestinations() {
        return delegate.getSortedDestinations();
    }

    @Override
    public Set<String> getDestinationsPaths() {
        return delegate.getDestinationsPaths();
    }

    private AbstractHTTPDestination wrap(final AbstractHTTPDestination destination) {
        try {
            return destination == null ? null : new WebSocketDestination(destination, this);
        } catch (final IOException e) {
            throw new IllegalStateException(e);
        }
    }
}

And of course our destination implementation:

public class WebSocketDestination extends AbstractHTTPDestination {

    private static final Logger LOG = LogUtils.getL7dLogger(ServletDestination.class);

    private final AbstractHTTPDestination delegate;

    public WebSocketDestination(final AbstractHTTPDestination delegate, final WebSocketRegistry registry)
            throws IOException {
        super(delegate.getBus(), registry, new EndpointInfo(), delegate.getPath(), false);
        this.delegate = delegate;
        this.cproviderFactory = new WebSocketContinuationFactory(registry);
    }

    @Override
    public void shutdown() {
        throw new UnsupportedOperationException();
    }

    @Override
    public void setMessageObserver(final MessageObserver observer) {
        throw new UnsupportedOperationException();
    }

    @Override
    protected Logger getLogger() {
        return LOG;
    }

    @Override
    public void invoke(final ServletConfig config, final ServletContext context, final HttpServletRequest req,
            final HttpServletResponse resp) throws IOException {
        // eager create the message to ensure we set our continuation for @Suspended
        Message inMessage = retrieveFromContinuation(req);
        if (inMessage == null) {
            inMessage = new MessageImpl();

            final ExchangeImpl exchange = new ExchangeImpl();
            exchange.setInMessage(inMessage);
            setupMessage(inMessage, config, context, req, resp);

            exchange.setSession(new HTTPSession(req));
            MessageImpl.class.cast(inMessage).setDestination(this);
        }

        delegate.invoke(config, context, req, resp);
    }

    // delegate all other methods to delegate field
}

The only trick here is to set cproviderFactory in the constructor. The continuation factory is really straightforward:

public class WebSocketContinuationFactory implements ContinuationProviderFactory {
    private static final String KEY = WebSocketContinuationFactory.class.getName();

    private final WebSocketRegistry registry;

    public WebSocketContinuationFactory(final WebSocketRegistry registry) {
        this.registry = registry;
    }

    @Override
    public ContinuationProvider createContinuationProvider(final Message inMessage, final HttpServletRequest req,
            final HttpServletResponse resp) {
        return new WebSocketContinuation(inMessage, req, resp, registry);
    }

    @Override
    public Message retrieveFromContinuation(final HttpServletRequest req) {
        return Message.class.cast(req.getAttribute(KEY));
    }
}

Finally, our continuation implementation closely mirrors the servlet flavor, since we don't need anything else in the websocket environment:

public class WebSocketContinuation implements ContinuationProvider, Continuation {
    private final Message message;
    private final HttpServletRequest request;
    private final HttpServletResponse response;
    private final WebSocketRegistry registry;
    private final ContinuationCallback callback;
    private Object object;
    private boolean resumed;
    private boolean pending;
    private boolean isNew;

    public WebSocketContinuation(final Message message, final HttpServletRequest request, final HttpServletResponse response,
            final WebSocketRegistry registry) {
        this.message = message;
        this.request = request;
        this.response = response;
        this.registry = registry;
        this.request.setAttribute(AbstractHTTPDestination.CXF_CONTINUATION_MESSAGE, message.getExchange().getInMessage());
        this.callback = message.getExchange().get(ContinuationCallback.class);
    }

    @Override
    public Continuation getContinuation() {
        return this;
    }

    @Override
    public void complete() {
        message.getExchange().getInMessage().remove(AbstractHTTPDestination.CXF_CONTINUATION_MESSAGE);
        if (callback != null) {
            final Exception ex = message.getExchange().get(Exception.class);
            if (ex == null) {
                callback.onComplete();
            } else {
                callback.onError(ex);
            }
        }
        try {
            response.getWriter().close();
        } catch (final IOException e) {
            throw new IllegalStateException(e);
        }
    }

    @Override
    public boolean suspend(final long timeout) {
        isNew = false;
        resumed = false;
        pending = true;
        message.getExchange().getInMessage().getInterceptorChain().suspend();
        return true;
    }

    @Override
    public void resume() {
        resumed = true;
        try {
            registry.controller.invoke(request, response);
        } catch (final ServletException e) {
            throw new IllegalStateException(e);
        }
    }

    @Override
    public void reset() {
        pending = false;
        resumed = false;
        isNew = false;
        object = null;
    }

    @Override
    public boolean isNew() {
        return isNew;
    }

    @Override
    public boolean isPending() {
        return pending;
    }

    @Override
    public boolean isResumed() {
        return resumed;
    }

    @Override
    public Object getObject() {
        return object;
    }

    @Override
    public void setObject(final Object o) {
        object = o;
    }
}

Finally, to wire it into the websocket deployment and complete our ServletContextListener, we convert each endpoint into a websocket endpoint configuration backed by JAXRSEndpoint:

oris.map(ori -> {
    final String uri = ori.getClassResourceInfo().getURITemplate().getValue() + ori.getURITemplate().getValue();
    return ServerEndpointConfig.Builder
            .create(
               Endpoint.class,
               "/websocket" + version + "/" + String.valueOf(ori.getHttpMethod()).toLowerCase(ENGLISH) + uri)
            .configurator(new ServerEndpointConfig.Configurator() {

                @Override
                public <T> T getEndpointInstance(final Class<T> clazz) throws InstantiationException {
                    final Map<String, List<String>> headers = new HashMap<>();
                    // faked request headers: Content-Type matches @Consumes, Accept matches @Produces
                    if (!ori.getConsumeTypes().isEmpty()) {
                        headers.put(HttpHeaders.CONTENT_TYPE,
                                singletonList(ori.getConsumeTypes().iterator().next().toString()));
                    }
                    if (!ori.getProduceTypes().isEmpty()) {
                        headers.put(HttpHeaders.ACCEPT,
                                singletonList(ori.getProduceTypes().iterator().next().toString()));
                    }
                    return (T) new JAXRSEndpoint(appBase, controller, servletContext, ori.getHttpMethod(), uri, headers);
                }
            }).build();
}).sorted(Comparator.comparing(ServerEndpointConfig::getPath))
.peek(e -> log.info("Deploying WebSocket(path={})", e.getPath())).forEach(config -> {
    try {
        container.addEndpoint(config);
    } catch (final DeploymentException e) {
        throw new IllegalStateException(e);
    }
});

This code is fairly simple once you get into it:

  • each JAX-RS operation is mapped to a websocket endpoint using the URI pattern /websocket/vX/<http method>/<endpoint path>
  • each websocket endpoint is a JAXRSEndpoint instance
  • the Content-Type and Accept headers are preset from the operation's consumed and produced media types (optional)
  • finally, each endpoint configuration is deployed using the addEndpoint method of the ServerContainer
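The path computation can be checked in isolation with a small sketch. WebSocketPaths is a hypothetical helper reproducing the string concatenation of the listener, with the class and method URI templates passed as plain strings.

```java
import java.util.Locale;

// Hypothetical helper reproducing the websocket path built in the listener.
final class WebSocketPaths {
    private WebSocketPaths() {
        // static helpers only
    }

    // "/websocket" + version + "/" + lower-cased HTTP method + class template + method template
    static String path(final String version, final String httpMethod,
                       final String classTemplate, final String methodTemplate) {
        return "/websocket" + version + "/" + String.valueOf(httpMethod).toLowerCase(Locale.ENGLISH)
                + classTemplate + methodTemplate;
    }
}
```

Keep in mind that a JSR 356 endpoint path only accepts simple {name} templates, so JAX-RS templates carrying a regular expression (such as {id: [0-9]+}) would likely need to be simplified before being deployed.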

Tip: the peek call logs every deployed path, which is quite handy after a while, or when you run into issues while writing a test or a client.

This code is not the most elegant you can find out there, but it works and lets you keep the JAX-RS programming model while communicating over websocket. It also lets you use any message format you want (we use a text-based format, but you could use JSON, XML, binary or custom formats) and reuse the same connection for any endpoint, since the message handler relies on the destination header of the request to route the message through the servlet controller.

The nice thing about such a solution is that you get a lightweight implementation (no Atmosphere dependency, long-lived connections instead of one HTTP exchange per call) with a clean and well-known programming model (JAX-RS, which follows a command-based pattern).