src/rabbit_stomp.erl
author Tony Garnock-Jones <tonyg@lshift.net>
Wed Apr 30 16:48:11 2008 +0100 (2008-04-30)
changeset 26 b37faa511709
parent 25 36fad000db32
child 28 392d8cc8449c
child 37 a1099f0d77e8
permissions -rw-r--r--
Avoid losing messages when the socket closes abruptly by calling
rabbit_channel:shutdown/1, which processes all pending work before
notifying us that the channel has closed.
     1 %%   The contents of this file are subject to the Mozilla Public License
     2 %%   Version 1.1 (the "License"); you may not use this file except in
     3 %%   compliance with the License. You may obtain a copy of the License at
     4 %%   http://www.mozilla.org/MPL/
     5 %%
     6 %%   Software distributed under the License is distributed on an "AS IS"
     7 %%   basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See the
     8 %%   License for the specific language governing rights and limitations
     9 %%   under the License.
    10 %%
    11 %%   The Original Code is RabbitMQ.
    12 %%
    13 %%   The Initial Developers of the Original Code are LShift Ltd.,
    14 %%   Cohesive Financial Technologies LLC., and Rabbit Technologies Ltd.
    15 %%
    16 %%   Portions created by LShift Ltd., Cohesive Financial
    17 %%   Technologies LLC., and Rabbit Technologies Ltd. are Copyright (C) 
    18 %%   2007 LShift Ltd., Cohesive Financial Technologies LLC., and Rabbit 
    19 %%   Technologies Ltd.; 
    20 %%
    21 %%   All Rights Reserved.
    22 %%
    23 %%   Contributor(s): ______________________________________.
    24 %%
    25 
    26 %% rabbit_stomp implements STOMP messaging semantics, as per protocol
    27 %% "version 1.0", at http://stomp.codehaus.org/Protocol
    28 
    29 -module(rabbit_stomp).
    30 
    31 -export([kickstart/0,
    32 	 start/1,
    33 	 listener_started/2, listener_stopped/2, start_client/1,
    34 	 start_link/0, init/1, mainloop/1]).
    35 
    36 -include("rabbit.hrl").
    37 -include("rabbit_framing.hrl").
    38 -include("stomp_frame.hrl").
    39 
    40 -record(state, {socket, session_id, channel, parse_state, ticket}).
    41 
    42 kickstart() ->
    43     {ok, StompListeners} = application:get_env(stomp_listeners),
    44     ok = start(StompListeners).
    45 
    46 start(Listeners) ->
    47     {ok,_} = supervisor:start_child(
    48                rabbit_sup,
    49                {rabbit_stomp_client_sup,
    50                 {tcp_client_sup, start_link,
    51                  [{local, rabbit_stomp_client_sup},
    52                   {rabbit_stomp,start_link,[]}]},
    53                 transient, infinity, supervisor, [tcp_client_sup]}),
    54     ok = start_listeners(Listeners),
    55     ok.
    56 
    57 start_listeners([]) ->
    58     ok;
    59 start_listeners([{Host, Port} | More]) ->
    60     {IPAddress, Name} = rabbit_networking:check_tcp_listener_address(rabbit_stomp_listener_sup,
    61 								     Host,
    62 								     Port),
    63     {ok,_} = supervisor:start_child(
    64                rabbit_sup,
    65                {Name,
    66                 {tcp_listener_sup, start_link,
    67 		 [IPAddress, Port,
    68 		  [{packet, raw},
    69 		   {reuseaddr, true}],
    70 		  {?MODULE, listener_started, []},
    71 		  {?MODULE, listener_stopped, []},
    72 		  {?MODULE, start_client, []}]},
    73                 transient, infinity, supervisor, [tcp_listener_sup]}),
    74     start_listeners(More).
    75 
    76 listener_started(_IPAddress, _Port) ->
    77     ok.
    78 
    79 listener_stopped(_IPAddress, _Port) ->
    80     ok.
    81 
    82 start_client(Sock) ->
    83     {ok, Child} = supervisor:start_child(rabbit_stomp_client_sup, []),
    84     ok = gen_tcp:controlling_process(Sock, Child),
    85     Child ! {go, Sock},
    86     Child.
    87 
    88 start_link() ->
    89     {ok, proc_lib:spawn_link(?MODULE, init, [self()])}.
    90 
    91 init(_Parent) ->
    92     receive
    93         {go, Sock} ->
    94 	    ok = inet:setopts(Sock, [{active, true}]),
    95 	    process_flag(trap_exit, true),
    96 	    ?MODULE:mainloop(#state{socket = Sock,
    97 				    channel = none,
    98 				    parse_state = stomp_frame:initial_state()})
    99     end.
   100 
   101 mainloop(State) ->
   102     receive
   103 	E = {'EXIT', _Pid, _Reason} ->
   104 	    handle_exit(E, State);
   105 	{tcp, _Sock, Bytes} ->
   106 	    process_received_bytes(Bytes, State);
   107 	{tcp_closed, _Sock} ->
   108 	    case State#state.channel of
   109 		none ->
   110 		    done;
   111 		ChPid ->
   112 		    rabbit_channel:shutdown(ChPid),
   113 		    ?MODULE:mainloop(State)
   114 	    end;
   115 	{send_command, Command} ->
   116 	    ?MODULE:mainloop(send_reply(Command, State));
   117 	{send_command_and_notify, QPid, TxPid, Method, Content} ->
   118 	    State1 = send_reply(Method, Content, State),
   119 	    rabbit_amqqueue:notify_sent(QPid, TxPid),
   120 	    ?MODULE:mainloop(State1);
   121 	{send_command_and_shutdown, Command} ->
   122 	    send_reply(Command, State),
   123 	    done;
   124 	shutdown ->
   125 	    done;
   126 	Data ->
   127 	    error_logger:error_msg("Internal error: unknown STOMP Data: ~p~n", [Data]),
   128 	    ?MODULE:mainloop(State)
   129     end.
   130 
   131 simple_method_sync_rpc(Method, State0) ->
   132     State = send_method(Method, State0),
   133     receive
   134 	E = {'EXIT', _Pid, _Reason} ->
   135 	    handle_exit(E, State);
   136 	{send_command, Reply} ->
   137 	    {ok, Reply, State}
   138     end.
   139 
   140 handle_exit({'EXIT', _Pid, {amqp, Code, Method}}, State) ->
   141     explain_amqp_death(Code, Method, State),
   142     done;
   143 handle_exit({'EXIT', Pid, Reason}, State) ->
   144     send_error("Error", "Process ~p exited with reason:~n~p~n", [Pid, Reason], State),
   145     done.
   146 
   147 process_received_bytes([], State) ->
   148     ?MODULE:mainloop(State);
   149 process_received_bytes(Bytes, State = #state{parse_state = ParseState}) ->
   150     case stomp_frame:parse(Bytes, ParseState) of
   151 	{more, ParseState1} ->
   152 	    ?MODULE:mainloop(State#state{parse_state = ParseState1});
   153 	{ok, Frame = #stomp_frame{command = Command}, Rest} ->
   154 	    %% io:format("Frame: ~p~n", [Frame]),
   155 	    case catch process_frame(Command, Frame,
   156 				     State#state{parse_state = stomp_frame:initial_state()}) of
   157 		{'EXIT', {amqp, Code, Method}} ->
   158 		    explain_amqp_death(Code, Method, State),
   159 		    done;
   160 		{'EXIT', Reason} ->
   161 		    send_error("Processing error", "~p~n", [Reason], State),
   162 		    done;
   163 		{ok, NewState} ->
   164 		    process_received_bytes(Rest, NewState);
   165 		stop ->
   166 		    done
   167 	    end;
   168 	{error, Reason} ->
   169 	    send_error("Invalid frame", "Could not parse frame: ~p~n", [Reason], State),
   170 	    done
   171     end.
   172 
   173 explain_amqp_death(Code, Method, State) ->
   174     send_error(atom_to_list(Code), "Method was ~p~n", [Method], State).
   175 
   176 send_reply(#'channel.close_ok'{}, State) ->
   177     State;
   178 send_reply(Command, State) ->
   179     error_logger:error_msg("STOMP Reply command unhandled: ~p~n", [Command]),
   180     State.
   181 
   182 send_reply(#'basic.deliver'{consumer_tag = ConsumerTag,
   183 			    delivery_tag = DeliveryTag,
   184 			    exchange = Exchange,
   185 			    routing_key = RoutingKey},
   186 	   #content{properties = #'P_basic'{headers = Headers},
   187 		    payload_fragments_rev = BodyFragmentsRev},
   188 	   State = #state{session_id = SessionId}) ->
   189     send_frame("MESSAGE",
   190 	       [{"destination", binary_to_list(RoutingKey)},
   191 		{"exchange", binary_to_list(Exchange)},
   192 		{"message-id", SessionId ++ "_" ++ integer_to_list(DeliveryTag)}]
   193 	       ++ case ConsumerTag of
   194 		      <<"Q_", _/binary>> ->
   195 			  [];
   196 		      <<"T_", Id/binary>> ->
   197 			  [{"subscription", binary_to_list(Id)}]
   198 		  end
   199 	       ++ adhoc_convert_headers(case Headers of
   200 					    undefined -> [];
   201 					    _ -> Headers
   202 					end),
   203 	       lists:concat(lists:reverse(lists:map(fun erlang:binary_to_list/1,
   204 						    BodyFragmentsRev))),
   205 	       State);
   206 send_reply(Command, Content, State) ->
   207     error_logger:error_msg("STOMP Reply command unhandled: ~p~n~p~n", [Command, Content]),
   208     State.
   209 
   210 adhoc_convert_headers([]) ->
   211     [];
   212 adhoc_convert_headers([{K, longstr, V} | Rest]) ->
   213     [{"X-" ++ binary_to_list(K), binary_to_list(V)} | adhoc_convert_headers(Rest)];
   214 adhoc_convert_headers([{K, signedint, V} | Rest]) ->
   215     [{"X-" ++ binary_to_list(K), integer_to_list(V)} | adhoc_convert_headers(Rest)];
   216 adhoc_convert_headers([_ | Rest]) ->
   217     adhoc_convert_headers(Rest).
   218 
   219 send_frame(Frame, State = #state{socket = Sock}) ->
   220     %% io:format("Sending ~p~n", [Frame]),
   221     ok = gen_tcp:send(Sock, stomp_frame:serialize(Frame)),
   222     State.
   223 
   224 send_frame(Command, Headers, Body, State) ->
   225     send_frame(#stomp_frame{command = Command,
   226 			    headers = Headers,
   227 			    body = Body},
   228 	       State).
   229 
   230 send_error(Message, Detail, State) ->
   231     error_logger:error_msg("STOMP error frame sent:~nMessage: ~p~nDetail: ~p~n",
   232 			   [Message, Detail]),
   233     send_frame("ERROR", [{"message", Message}], Detail, State).
   234 
   235 send_error(Message, Format, Args, State) ->
   236     send_error(Message, lists:flatten(io_lib:format(Format, Args)), State).
   237 
   238 process_frame("CONNECT", Frame, State = #state{channel = none}) ->
   239     {ok, DefaultVHost} = application:get_env(default_vhost),
   240     do_login(stomp_frame:header(Frame, "login"),
   241 	     stomp_frame:header(Frame, "passcode"),
   242 	     stomp_frame:header(Frame, "virtual-host", binary_to_list(DefaultVHost)),
   243 	     stomp_frame:header(Frame, "realm", "/data"),
   244 	     State);
   245 process_frame("DISCONNECT", _Frame, _State = #state{channel = none}) ->
   246     stop;
   247 process_frame(_Command, _Frame, State = #state{channel = none}) ->
   248     {ok, send_error("Illegal command",
   249 		    "You must log in using CONNECT first\n",
   250 		    State)};
   251 process_frame(Command, Frame, State) ->
   252     case process_command(Command, Frame, State) of
   253 	{ok, State1} ->
   254 	    {ok, case stomp_frame:header(Frame, "receipt") of
   255 		     {ok, Id} ->
   256 			 send_frame("RECEIPT", [{"receipt-id", Id}], "", State1);
   257 		     not_found ->
   258 			 State1
   259 		 end};
   260 	stop ->
   261 	    stop
   262     end.
   263 
   264 send_method(Method, State = #state{channel = ChPid}) ->
   265     ok = rabbit_channel:do(ChPid, Method),
   266     State.
   267 
   268 send_method(Method, Properties, Body, State = #state{channel = ChPid}) ->
   269     ok = rabbit_channel:do(ChPid,
   270 			   Method,
   271 			   #content{class_id = 60, %% basic
   272 				    properties = Properties,
   273 				    properties_bin = none,
   274 				    payload_fragments_rev = [list_to_binary(Body)]}),
   275     State.
   276 
   277 do_login({ok, Login}, {ok, Passcode}, VirtualHost, Realm, State) ->
   278     U = rabbit_access_control:user_pass_login(list_to_binary(Login),
   279 					      list_to_binary(Passcode)),
   280     ok = rabbit_access_control:check_vhost_access(U, list_to_binary(VirtualHost)),
   281     ChPid = 
   282 	rabbit_channel:start_link(self(), self(), U#user.username, list_to_binary(VirtualHost)),
   283     {ok, #'channel.open_ok'{}, State1} =
   284 	simple_method_sync_rpc(#'channel.open'{out_of_band = <<"">>},
   285 			       State#state{channel = ChPid}),
   286     SessionId = rabbit_misc:string_guid("session"),
   287     {ok, #'access.request_ok'{ticket = Ticket}, State2} =
   288 	simple_method_sync_rpc(#'access.request'{realm = list_to_binary(Realm),
   289 						 exclusive = false,
   290 						 passive = true,
   291 						 active = true,
   292 						 write = true,
   293 						 read = true},
   294 			       send_frame("CONNECTED",
   295 					  [{"session", SessionId}],
   296 					  "",
   297 					  State1#state{session_id = SessionId})),
   298     {ok, State2#state{ticket = Ticket}};
   299 do_login(_, _, _, _, State) ->
   300     {ok, send_error("Bad CONNECT", "Missing login or passcode header(s)\n", State)}.
   301 
   302 user_header_key("X-" ++ UserKey) -> UserKey;
   303 user_header_key(_) -> false.
   304 
   305 make_string_table(_KeyFilter, []) -> [];
   306 make_string_table(KeyFilter, [{K, V} | Rest]) ->
   307     case KeyFilter(K) of
   308 	false ->
   309 	    make_string_table(KeyFilter, Rest);
   310 	NewK ->
   311 	    [{list_to_binary(NewK), longstr, list_to_binary(V)}
   312 	     | make_string_table(KeyFilter, Rest)]
   313     end.
   314 
   315 transactional(Frame) ->
   316     case stomp_frame:header(Frame, "transaction") of
   317 	{ok, Transaction} ->
   318 	    {yes, Transaction};
   319 	not_found ->
   320 	    no
   321     end.
   322 
   323 transactional_action(Frame, Name, Fun, State) ->
   324     case transactional(Frame) of
   325 	{yes, Transaction} ->
   326 	    Fun(Transaction, State);
   327 	no ->
   328 	    {ok, send_error("Missing transaction",
   329 			    Name ++ " must include a 'transaction' header\n",
   330 			    State)}
   331     end.
   332 
   333 with_transaction(Transaction, State, Fun) ->
   334     case get({transaction, Transaction}) of
   335 	undefined ->
   336 	    {ok, send_error("Bad transaction",
   337 			    "Invalid transaction identifier: ~p~n", [Transaction],
   338 			    State)};
   339 	Actions ->
   340 	    Fun(Actions, State)
   341     end.
   342 
   343 begin_transaction(Transaction, State) ->
   344     put({transaction, Transaction}, []),
   345     {ok, State}.
   346 
   347 extend_transaction(Transaction, Action, State0) ->
   348     with_transaction(Transaction, State0,
   349 		     fun (Actions, State) ->
   350 			     put({transaction, Transaction}, [Action | Actions]),
   351 			     {ok, State}
   352 		     end).
   353 
   354 commit_transaction(Transaction, State0) ->
   355     with_transaction(Transaction, State0,
   356 		     fun (Actions, State) ->
   357 			     FinalState = lists:foldr(fun perform_transaction_action/2,
   358 						      State,
   359 						      Actions),
   360 			     erase({transaction, Transaction}),
   361 			     {ok, FinalState}
   362 		     end).
   363 
   364 abort_transaction(Transaction, State0) ->
   365     with_transaction(Transaction, State0,
   366 		     fun (_Actions, State) ->
   367 			     erase({transaction, Transaction}),
   368 			     {ok, State}
   369 		     end).
   370 
   371 perform_transaction_action({Method}, State) ->
   372     send_method(Method, State);
   373 perform_transaction_action({Method, Props, Body}, State) ->
   374     send_method(Method, Props, Body, State).
   375 
   376 process_command("BEGIN", Frame, State) ->
   377     transactional_action(Frame, "BEGIN", fun begin_transaction/2, State);
   378 process_command("SEND",
   379 		Frame = #stomp_frame{headers = Headers, body = Body},
   380 		State = #state{ticket = Ticket}) ->
   381     case stomp_frame:header(Frame, "destination") of
   382 	{ok, RoutingKeyStr} ->
   383 	    ExchangeStr = stomp_frame:header(Frame, "exchange", ""),
   384 	    Props = #'P_basic'{
   385 	      content_type = stomp_frame:binary_header(Frame, "content-type", <<"text/plain">>),
   386 	      headers = make_string_table(fun user_header_key/1, Headers),
   387 	      delivery_mode = stomp_frame:integer_header(Frame, "delivery-mode", undefined),
   388 	      priority = stomp_frame:integer_header(Frame, "priority", undefined),
   389 	      correlation_id = stomp_frame:binary_header(Frame, "correlation-id", undefined),
   390 	      reply_to = stomp_frame:binary_header(Frame, "reply-to", undefined),
   391 	      message_id = stomp_frame:binary_header(Frame, "message-id", undefined)
   392 	     },
   393 	    Method = #'basic.publish'{ticket = Ticket,
   394 				      exchange = list_to_binary(ExchangeStr),
   395 				      routing_key = list_to_binary(RoutingKeyStr),
   396 				      mandatory = false,
   397 				      immediate = false},
   398 	    case transactional(Frame) of
   399 		{yes, Transaction} ->
   400 		    extend_transaction(Transaction, {Method, Props, Body}, State);
   401 		no ->
   402 		    {ok, send_method(Method, Props, Body, State)}
   403 	    end;
   404 	not_found ->
   405 	    {ok, send_error("Missing destination",
   406 			    "SEND must include a 'destination', and optional 'exchange' header\n",
   407 			    State)}
   408     end;
   409 process_command("ACK", Frame, State = #state{session_id = SessionId}) ->
   410     case stomp_frame:header(Frame, "message-id") of
   411 	{ok, IdStr} ->
   412 	    IdPrefix = SessionId ++ "_",
   413 	    case string:substr(IdStr, 1, length(IdPrefix)) of
   414 		IdPrefix ->
   415 		    DeliveryTag = list_to_integer(string:substr(IdStr, length(IdPrefix) + 1)),
   416 		    Method = #'basic.ack'{delivery_tag = DeliveryTag,
   417 					  multiple = false},
   418 		    case transactional(Frame) of
   419 			{yes, Transaction} ->
   420 			    extend_transaction(Transaction, {Method}, State);
   421 			no ->
   422 			    {ok, send_method(Method, State)}
   423 		    end;
   424 		_ ->
   425 		    rabbit_misc:die(command_invalid, 'basic.ack')
   426 	    end;
   427 	not_found ->
   428 	    {ok, send_error("Missing message-id",
   429 			    "ACK must include a 'message-id' header\n",
   430 			    State)}
   431     end;
   432 process_command("COMMIT", Frame, State) ->
   433     transactional_action(Frame, "COMMIT", fun commit_transaction/2, State);
   434 process_command("ABORT", Frame, State) ->
   435     transactional_action(Frame, "ABORT", fun abort_transaction/2, State);
   436 process_command("SUBSCRIBE",
   437 		Frame = #stomp_frame{headers = Headers},
   438 		State = #state{ticket = Ticket}) ->
   439     AckMode = case stomp_frame:header(Frame, "ack", "auto") of
   440 		  "auto" -> auto;
   441 		  "client" -> client
   442 	      end,
   443     case stomp_frame:header(Frame, "destination") of
   444 	{ok, QueueStr} ->
   445 	    ConsumerTag = case stomp_frame:header(Frame, "id") of
   446 			      {ok, Str} ->
   447 				  list_to_binary("T_" ++ Str);
   448 			      not_found ->
   449 				  list_to_binary("Q_" ++ QueueStr)
   450 			  end,
   451 	    Queue = list_to_binary(QueueStr),
   452 	    {ok, send_method(#'basic.consume'{ticket = Ticket,
   453 					      queue = Queue,
   454 					      consumer_tag = ConsumerTag,
   455 					      no_local = false,
   456 					      no_ack = (AckMode == auto),
   457 					      exclusive = false,
   458 					      nowait = true},
   459 			     send_method(#'queue.declare'{ticket = Ticket,
   460 							  queue = Queue,
   461 							  passive = false,
   462 							  durable = false,
   463 							  exclusive = falxse,
   464 							  auto_delete = true,
   465 							  nowait = true,
   466 							  arguments =
   467 							    make_string_table(fun user_header_key/1,
   468 									      Headers)},
   469 					 State))};
   470 	not_found ->
   471 	    {ok, send_error("Missing destination",
   472 			    "SUBSCRIBE must include a 'destination' header\n",
   473 			    State)}
   474     end;
   475 process_command("UNSUBSCRIBE", Frame, State) ->
   476     ConsumerTag = case stomp_frame:header(Frame, "id") of
   477 		      {ok, IdStr} ->
   478 			  list_to_binary("T_" ++ IdStr);
   479 		      not_found ->
   480 			  case stomp_frame:header(Frame, "destination") of
   481 			      {ok, QueueStr} ->
   482 				  list_to_binary("Q_" ++ QueueStr);
   483 			      not_found ->
   484 				  missing
   485 			  end
   486 		  end,
   487     if
   488 	ConsumerTag == missing ->
   489 	    {ok, send_error("Missing destination or id",
   490 			    "UNSUBSCRIBE must include a 'destination' or 'id' header\n",
   491 			    State)};
   492 	true ->
   493 	    {ok, send_method(#'basic.cancel'{consumer_tag = ConsumerTag,
   494 					     nowait = true},
   495 			    State)}
   496     end;
   497 process_command("DISCONNECT", _Frame, State) ->
   498     {ok, send_method(#'channel.close'{reply_code = 200, reply_text = <<"">>,
   499 				      class_id = 0, method_id = 0}, State)};
   500 process_command(Command, _Frame, State) ->
   501     {ok, send_error("Bad command",
   502 		    "Could not interpret command " ++ Command ++ "\n",
   503 		    State)}.