diff --git a/include/lhttpc.hrl b/include/lhttpc.hrl new file mode 100644 index 0000000..938454f --- /dev/null +++ b/include/lhttpc.hrl @@ -0,0 +1,34 @@ +%%% ---------------------------------------------------------------------------- +%%% Copyright (c) 2009, Erlang Training and Consulting Ltd. +%%% All rights reserved. +%%% +%%% Redistribution and use in source and binary forms, with or without +%%% modification, are permitted provided that the following conditions are met: +%%% * Redistributions of source code must retain the above copyright +%%% notice, this list of conditions and the following disclaimer. +%%% * Redistributions in binary form must reproduce the above copyright +%%% notice, this list of conditions and the following disclaimer in the +%%% documentation and/or other materials provided with the distribution. +%%% * Neither the name of Erlang Training and Consulting Ltd. nor the +%%% names of its contributors may be used to endorse or promote products +%%% derived from this software without specific prior written permission. +%%% +%%% THIS SOFTWARE IS PROVIDED BY Erlang Training and Consulting Ltd. ''AS IS'' +%%% AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +%%% IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +%%% ARE DISCLAIMED. IN NO EVENT SHALL Erlang Training and Consulting Ltd. BE +%%% LIABLE SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR +%%% BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +%%% WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +%%% OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +%%% ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +%%% ---------------------------------------------------------------------------- + +-record(lhttpc_url, { + host :: string(), + port :: integer(), + path :: string(), + is_ssl:: boolean(), + user = "" :: string(), + password = "" :: string() +}). diff --git a/rebar.config b/rebar.config index 51073dd..791c88f 100644 --- a/rebar.config +++ b/rebar.config @@ -2,3 +2,4 @@ {erl_opts, [debug_info]}. {cover_enabled, true}. {dialyzer_opts, [{warnings, [unmatched_returns]}]}. +{eunit_opts, [verbose]}. diff --git a/src/lhttpc.app.src b/src/lhttpc.app.src index 987cb78..1c3d93b 100644 --- a/src/lhttpc.app.src +++ b/src/lhttpc.app.src @@ -34,6 +34,6 @@ {registered, [lhttpc_manager]}, {applications, [kernel, stdlib, ssl, crypto]}, {mod, {lhttpc, nil}}, - {env, [{connection_timeout, 300000}]} + {env, [{connection_timeout, 300000}, {pool_size, 50}]} ]}. diff --git a/src/lhttpc.erl b/src/lhttpc.erl index 974e271..530dd61 100644 --- a/src/lhttpc.erl +++ b/src/lhttpc.erl @@ -44,6 +44,7 @@ ]). -include("lhttpc_types.hrl"). +-include("lhttpc.hrl"). -type result() :: {ok, {{pos_integer(), string()}, headers(), binary()}} | {error, atom()}. @@ -155,13 +156,19 @@ request(URL, Method, Hdrs, Body, Timeout) -> %% {connect_options, [ConnectOptions]} | %% {send_retry, integer()} | %% {partial_upload, WindowSize} | -%% {partial_download, PartialDownloadOptions} +%% {partial_download, PartialDownloadOptions} | +%% {proxy, ProxyUrl} | +%% {proxy_ssl_options, SslOptions} | +%% {pool, LhttcPool} %% Milliseconds = integer() %% ConnectOptions = term() %% WindowSize = integer() | infinity %% PartialDownloadOptions = [PartialDownloadOption] %% PartialDowloadOption = {window_size, WindowSize} | %% {part_size, PartSize} +%% ProxyUrl = string() +%% SslOptions = [any()] +%% LhttcPool = pid() | atom() %% PartSize = integer() | infinity %% Result = {ok, {{StatusCode, ReasonPhrase}, Hdrs, ResponseBody}} | %% {ok, UploadState} | {error, Reason} @@ -171,7 +178,7 @@ request(URL, Method, Hdrs, Body, Timeout) -> %% Reason = connection_closed | connect_timeout | timeout %% @doc Sends a request with a body. %% Would be the same as calling
-%% {Host, Port, Path, Ssl} = lhttpc_lib:parse_url(URL),
+%% #lhttpc_url{host = Host, port = Port, path = Path, is_ssl = Ssl} = lhttpc_lib:parse_url(URL),
 %% request(Host, Port, Path, Ssl, Method, Hdrs, Body, Timeout, Options).
 %% 
%% @@ -182,8 +189,22 @@ request(URL, Method, Hdrs, Body, Timeout) -> -spec request(string(), string() | atom(), headers(), iolist(), pos_integer() | infinity, [option()]) -> result(). request(URL, Method, Hdrs, Body, Timeout, Options) -> - {Host, Port, Path, Ssl} = lhttpc_lib:parse_url(URL), - request(Host, Port, Ssl, Path, Method, Hdrs, Body, Timeout, Options). + #lhttpc_url{ + host = Host, + port = Port, + path = Path, + is_ssl = Ssl, + user = User, + password = Passwd + } = lhttpc_lib:parse_url(URL), + Headers = case User of + "" -> + Hdrs; + _ -> + Auth = "Basic " ++ binary_to_list(base64:encode(User ++ ":" ++ Passwd)), + lists:keystore("Authorization", 1, Hdrs, {"Authorization", Auth}) + end, + request(Host, Port, Ssl, Path, Method, Headers, Body, Timeout, Options). %% @spec (Host, Port, Ssl, Path, Method, Hdrs, RequestBody, Timeout, Options) -> %% Result @@ -202,12 +223,18 @@ request(URL, Method, Hdrs, Body, Timeout, Options) -> %% {connect_options, [ConnectOptions]} | %% {send_retry, integer()} | %% {partial_upload, WindowSize} | -%% {partial_download, PartialDownloadOptions} +%% {partial_download, PartialDownloadOptions} | +%% {proxy, ProxyUrl} | +%% {proxy_ssl_options, SslOptions} | +%% {pool, LhttcPool} %% Milliseconds = integer() %% WindowSize = integer() %% PartialDownloadOptions = [PartialDownloadOption] %% PartialDowloadOption = {window_size, WindowSize} | %% {part_size, PartSize} +%% ProxyUrl = string() +%% SslOptions = [any()] +%% LhttcPool = pid() | atom() %% PartSize = integer() | infinity %% Result = {ok, {{StatusCode, ReasonPhrase}, Hdrs, ResponseBody}} %% | {error, Reason} @@ -314,6 +341,15 @@ request(URL, Method, Hdrs, Body, Timeout, Options) -> %% `undefined'. The functions {@link get_body_part/1} and %% {@link get_body_part/2} can be used to read body parts in the calling %% process. +%% +%% `{proxy, ProxyUrl}' if this option is specified, a proxy server is used as +%% an intermediary for all communication with the destination server. The link +%% to the proxy server is established with the HTTP CONNECT method (RFC2817). +%% Example value: {proxy, "http://john:doe@myproxy.com:3128"} +%% +%% `{proxy_ssl_options, SslOptions}' this is a list of SSL options to use for +%% the SSL session created after the proxy connection is established. For a +%% list of all available options, please check OTP's ssl module manpage. %% @end -spec request(string(), 1..65535, true | false, string(), atom() | string(), headers(), iolist(), pos_integer(), [option()]) -> result(). @@ -553,6 +589,15 @@ verify_options([{partial_download, DownloadOptions} | Options], Errors) verify_options([{connect_options, List} | Options], Errors) when is_list(List) -> verify_options(Options, Errors); +verify_options([{proxy, List} | Options], Errors) + when is_list(List) -> + verify_options(Options, Errors); +verify_options([{proxy_ssl_options, List} | Options], Errors) + when is_list(List) -> + verify_options(Options, Errors); +verify_options([{pool, PidOrName} | Options], Errors) + when is_pid(PidOrName); is_atom(PidOrName) -> + verify_options(Options, Errors); verify_options([Option | Options], Errors) -> verify_options(Options, [Option | Errors]); verify_options([], []) -> diff --git a/src/lhttpc_client.erl b/src/lhttpc_client.erl index 2b53ebe..a0df05d 100644 --- a/src/lhttpc_client.erl +++ b/src/lhttpc_client.erl @@ -34,6 +34,7 @@ -export([request/9]). -include("lhttpc_types.hrl"). +-include("lhttpc.hrl"). -record(client_state, { host :: string(), @@ -52,9 +53,12 @@ upload_window :: non_neg_integer() | infinity, partial_download = false :: true | false, download_window = infinity :: timeout(), - part_size :: non_neg_integer() | infinity + part_size :: non_neg_integer() | infinity, %% in case of infinity we read whatever data we can get from %% the wire at that point or in case of chunked one chunk + proxy :: undefined | #lhttpc_url{}, + proxy_ssl_options = [] :: [any()], + proxy_setup = false :: true | false }). -define(CONNECTION_HDR(HDRS, DEFAULT), @@ -100,10 +104,21 @@ execute(From, Host, Port, Ssl, Path, Method, Hdrs, Body, Options) -> PartialDownload = proplists:is_defined(partial_download, Options), PartialDownloadOptions = proplists:get_value(partial_download, Options, []), NormalizedMethod = lhttpc_lib:normalize_method(Method), + Proxy = case proplists:get_value(proxy, Options) of + undefined -> + undefined; + ProxyUrl when is_list(ProxyUrl), not Ssl -> + % The point of HTTP CONNECT proxying is to use TLS tunneled in + % a plain HTTP/1.1 connection to the proxy (RFC2817). + throw(origin_server_not_https); + ProxyUrl when is_list(ProxyUrl) -> + lhttpc_lib:parse_url(ProxyUrl) + end, {ChunkedUpload, Request} = lhttpc_lib:format_request(Path, NormalizedMethod, Hdrs, Host, Port, Body, PartialUpload), SocketRequest = {socket, self(), Host, Port, Ssl}, - Socket = case gen_server:call(lhttpc_manager, SocketRequest, infinity) of + Pool = proplists:get_value(pool, Options, whereis(lhttpc_manager)), + Socket = case gen_server:call(Pool, SocketRequest, infinity) of {ok, S} -> S; % Re-using HTTP/1.1 connections no_socket -> undefined % Opening a new HTTP/1.1 connection end, @@ -127,7 +142,10 @@ execute(From, Host, Port, Ssl, Path, Method, Hdrs, Body, Options) -> download_window = proplists:get_value(window_size, PartialDownloadOptions, infinity), part_size = proplists:get_value(part_size, - PartialDownloadOptions, infinity) + PartialDownloadOptions, infinity), + proxy = Proxy, + proxy_setup = (Socket =/= undefined), + proxy_ssl_options = proplists:get_value(proxy_ssl_options, Options, []) }, Response = case send_request(State) of {R, undefined} -> @@ -141,11 +159,10 @@ execute(From, Host, Port, Ssl, Path, Method, Hdrs, Body, Options) -> % * The socket was closed remotely already % * Due to an error in this module (returning dead sockets for % instance) - ManagerPid = whereis(lhttpc_manager), - case lhttpc_sock:controlling_process(NewSocket, ManagerPid, Ssl) of + case lhttpc_sock:controlling_process(NewSocket, Pool, Ssl) of ok -> - gen_server:cast(lhttpc_manager, - {done, Host, Port, Ssl, NewSocket}); + DoneMsg = {done, Host, Port, Ssl, NewSocket}, + ok = gen_server:call(Pool, DoneMsg, infinity); _ -> ok end, @@ -157,11 +174,17 @@ send_request(#client_state{attempts = 0}) -> % Don't try again if the number of allowed attempts is 0. throw(connection_closed); send_request(#client_state{socket = undefined} = State) -> - Host = State#client_state.host, - Port = State#client_state.port, - Ssl = State#client_state.ssl, + {Host, Port, Ssl} = request_first_destination(State), Timeout = State#client_state.connect_timeout, - ConnectOptions = State#client_state.connect_options, + ConnectOptions0 = State#client_state.connect_options, + ConnectOptions = case (not lists:member(inet, ConnectOptions0)) andalso + (not lists:member(inet6, ConnectOptions0)) andalso + is_ipv6_host(Host) of + true -> + [inet6 | ConnectOptions0]; + false -> + ConnectOptions0 + end, SocketOptions = [binary, {packet, http}, {active, false} | ConnectOptions], case lhttpc_sock:connect(Host, Port, SocketOptions, Timeout, Ssl) of {ok, Socket} -> @@ -174,6 +197,46 @@ send_request(#client_state{socket = undefined} = State) -> {error, Reason} -> erlang:error(Reason) end; +send_request(#client_state{proxy = #lhttpc_url{}, proxy_setup = false} = State) -> + #lhttpc_url{ + user = User, + password = Passwd, + is_ssl = Ssl + } = State#client_state.proxy, + #client_state{ + host = DestHost, + port = Port, + socket = Socket + } = State, + Host = case inet_parse:address(DestHost) of + {ok, {_, _, _, _, _, _, _, _}} -> + % IPv6 address literals are enclosed by square brackets (RFC2732) + [$[, DestHost, $], $:, integer_to_list(Port)]; + _ -> + [DestHost, $:, integer_to_list(Port)] + end, + ConnectRequest = [ + "CONNECT ", Host, " HTTP/1.1\r\n", + "Host: ", Host, "\r\n", + case User of + "" -> + ""; + _ -> + ["Proxy-Authorization: Basic ", + base64:encode(User ++ ":" ++ Passwd), "\r\n"] + end, + "\r\n" + ], + case lhttpc_sock:send(Socket, ConnectRequest, Ssl) of + ok -> + read_proxy_connect_response(State, nil, nil); + {error, closed} -> + lhttpc_sock:close(Socket, Ssl), + throw(proxy_connection_closed); + {error, Reason} -> + lhttpc_sock:close(Socket, Ssl), + erlang:error(Reason) + end; send_request(State) -> Socket = State#client_state.socket, Ssl = State#client_state.ssl, @@ -196,6 +259,49 @@ send_request(State) -> erlang:error(Reason) end. +request_first_destination(#client_state{proxy = #lhttpc_url{} = Proxy}) -> + {Proxy#lhttpc_url.host, Proxy#lhttpc_url.port, Proxy#lhttpc_url.is_ssl}; +request_first_destination(#client_state{host = Host, port = Port, ssl = Ssl}) -> + {Host, Port, Ssl}. + +read_proxy_connect_response(State, StatusCode, StatusText) -> + Socket = State#client_state.socket, + ProxyIsSsl = (State#client_state.proxy)#lhttpc_url.is_ssl, + case lhttpc_sock:recv(Socket, ProxyIsSsl) of + {ok, {http_response, _Vsn, Code, Reason}} -> + read_proxy_connect_response(State, Code, Reason); + {ok, {http_header, _, _Name, _, _Value}} -> + read_proxy_connect_response(State, StatusCode, StatusText); + {ok, http_eoh} when StatusCode >= 100, StatusCode =< 199 -> + % RFC 2616, section 10.1: + % A client MUST be prepared to accept one or more + % 1xx status responses prior to a regular + % response, even if the client does not expect a + % 100 (Continue) status message. Unexpected 1xx + % status responses MAY be ignored by a user agent. + read_proxy_connect_response(State, nil, nil); + {ok, http_eoh} when StatusCode >= 200, StatusCode < 300 -> + % RFC2817, any 2xx code means success. + ConnectOptions = State#client_state.connect_options, + SslOptions = State#client_state.proxy_ssl_options, + Timeout = State#client_state.connect_timeout, + State2 = case ssl:connect(Socket, SslOptions ++ ConnectOptions, Timeout) of + {ok, SslSocket} -> + State#client_state{socket = SslSocket, proxy_setup = true}; + {error, Reason} -> + lhttpc_sock:close(Socket, ProxyIsSsl), + erlang:error({proxy_connection_failed, Reason}) + end, + send_request(State2); + {ok, http_eoh} -> + throw({proxy_connection_refused, StatusCode, StatusText}); + {error, closed} -> + lhttpc_sock:close(Socket, ProxyIsSsl), + throw(proxy_connection_closed); + {error, Reason} -> + erlang:error({proxy_connection_failed, Reason}) + end. + partial_upload(State) -> Response = {ok, {self(), State#client_state.upload_window}}, State#client_state.requester ! {response, self(), Response}, @@ -664,3 +770,24 @@ maybe_close_socket(Socket, Ssl, _, ReqHdrs, RespHdrs) -> ClientConnection =/= "close", ServerConnection =:= "keep-alive" -> Socket end. + +is_ipv6_host(Host) -> + case inet_parse:address(Host) of + {ok, {_, _, _, _, _, _, _, _}} -> + true; + {ok, {_, _, _, _}} -> + false; + _ -> + % Prefer IPv4 over IPv6. + case inet:getaddr(Host, inet) of + {ok, _} -> + false; + _ -> + case inet:getaddr(Host, inet6) of + {ok, _} -> + true; + _ -> + false + end + end + end. diff --git a/src/lhttpc_lib.erl b/src/lhttpc_lib.erl index 07eb6ce..88ed27f 100644 --- a/src/lhttpc_lib.erl +++ b/src/lhttpc_lib.erl @@ -42,6 +42,7 @@ -export([format_hdrs/1, dec/1]). -include("lhttpc_types.hrl"). +-include("lhttpc.hrl"). %% @spec header_value(Header, Headers) -> undefined | term() %% Header = string() @@ -91,26 +92,63 @@ maybe_atom_to_list(Atom) when is_atom(Atom) -> maybe_atom_to_list(List) when is_list(List) -> List. -%% @spec (URL) -> {Host, Port, Path, Ssl} +%% @spec (URL) -> #lhttpc_url{} %% URL = string() -%% Host = string() -%% Port = integer() -%% Path = string() -%% Ssl = boolean() %% @doc --spec parse_url(string()) -> {string(), integer(), string(), boolean()}. +-spec parse_url(string()) -> #lhttpc_url{}. parse_url(URL) -> % XXX This should be possible to do with the re module? - {Scheme, HostPortPath} = split_scheme(URL), + {Scheme, CredsHostPortPath} = split_scheme(URL), + {User, Passwd, HostPortPath} = split_credentials(CredsHostPortPath), {Host, PortPath} = split_host(HostPortPath, []), {Port, Path} = split_port(Scheme, PortPath, []), - {string:to_lower(Host), Port, Path, Scheme =:= https}. + #lhttpc_url{ + host = string:to_lower(Host), + port = Port, + path = Path, + user = User, + password = Passwd, + is_ssl = (Scheme =:= https) + }. split_scheme("http://" ++ HostPortPath) -> {http, HostPortPath}; split_scheme("https://" ++ HostPortPath) -> {https, HostPortPath}. +split_credentials(CredsHostPortPath) -> + case string:tokens(CredsHostPortPath, "@") of + [HostPortPath] -> + {"", "", HostPortPath}; + [Creds, HostPortPath] -> + % RFC1738 (section 3.1) says: + % "The user name (and password), if present, are followed by a + % commercial at-sign "@". Within the user and password field, any ":", + % "@", or "/" must be encoded." + % The mentioned encoding is the "percent" encoding. + case string:tokens(Creds, ":") of + [User] -> + % RFC1738 says ":password" is optional + {User, "", HostPortPath}; + [User, Passwd] -> + {User, Passwd, HostPortPath} + end + end. + +split_host("[" ++ Rest, []) -> + % IPv6 address literals are enclosed by square brackets (RFC2732) + case string:str(Rest, "]") of + 0 -> + split_host(Rest, "["); + N -> + {IPv6Address, "]" ++ PortPath0} = lists:split(N - 1, Rest), + case PortPath0 of + ":" ++ PortPath -> + {IPv6Address, PortPath}; + _ -> + {IPv6Address, PortPath0} + end + end; split_host([$: | PortPath], Host) -> {lists:reverse(Host), PortPath}; split_host([$/ | _] = PortPath, Host) -> @@ -250,5 +288,17 @@ is_chunked(Hdrs) -> dec(Num) when is_integer(Num) -> Num - 1; dec(Else) -> Else. -host(Host, 80) -> Host; -host(Host, Port) -> [Host, $:, integer_to_list(Port)]. +host(Host, 80) -> maybe_ipv6_enclose(Host); +% When proxying after an HTTP CONNECT session is established, squid doesn't +% like the :443 suffix in the Host header. +host(Host, 443) -> maybe_ipv6_enclose(Host); +host(Host, Port) -> [maybe_ipv6_enclose(Host), $:, integer_to_list(Port)]. + +maybe_ipv6_enclose(Host) -> + case inet_parse:address(Host) of + {ok, {_, _, _, _, _, _, _, _}} -> + % IPv6 address literals are enclosed by square brackets (RFC2732) + [$[, Host, $]]; + _ -> + Host + end. diff --git a/src/lhttpc_manager.erl b/src/lhttpc_manager.erl index 4ac4ec2..bd58845 100644 --- a/src/lhttpc_manager.erl +++ b/src/lhttpc_manager.erl @@ -25,6 +25,7 @@ %%% ---------------------------------------------------------------------------- %%% @author Oscar Hellström +%%% @author Filipe David Manana %%% @doc Connection manager for the HTTP client. %%% This gen_server is responsible for keeping track of persistent %%% connections to HTTP servers. The only interesting API is @@ -35,9 +36,10 @@ -export([ start_link/0, - connection_count/0, + start_link/1, connection_count/1, - update_connection_timeout/1 + connection_count/2, + update_connection_timeout/2 ]). -export([ init/1, @@ -53,19 +55,23 @@ -record(httpc_man, { destinations = dict:new(), sockets = dict:new(), + clients = dict:new(), % Pid => {Dest, MonRef} + queues = dict:new(), % Dest => queue of Froms + max_pool_size = 50 :: non_neg_integer(), timeout = 300000 :: non_neg_integer() }). -%% @spec () -> Count +%% @spec (PoolPidOrName) -> Count %% Count = integer() %% @doc Returns the total number of active connections maintained by the -%% httpc manager. +%% specified lhttpc pool (manager). %% @end --spec connection_count() -> non_neg_integer(). -connection_count() -> - gen_server:call(?MODULE, connection_count). +-spec connection_count(pid() | atom()) -> non_neg_integer(). +connection_count(PidOrName) -> + gen_server:call(PidOrName, connection_count). -%% @spec (Destination) -> Count +%% @spec (PoolPidOrName, Destination) -> Count +%% PoolPidOrName = pid() | atom() %% Destination = {Host, Port, Ssl} %% Host = string() %% Port = integer() @@ -74,43 +80,74 @@ connection_count() -> %% @doc Returns the number of active connections to the specific %% `Destination' maintained by the httpc manager. %% @end --spec connection_count({string(), pos_integer(), boolean()}) -> +-spec connection_count(pid() | atom(), {string(), pos_integer(), boolean()}) -> non_neg_integer(). -connection_count({Host, Port, Ssl}) -> +connection_count(PidOrName, {Host, Port, Ssl}) -> Destination = {string:to_lower(Host), Port, Ssl}, - gen_server:call(?MODULE, {connection_count, Destination}). + gen_server:call(PidOrName, {connection_count, Destination}). -%% @spec (Timeout) -> ok +%% @spec (PoolPidOrName, Timeout) -> ok +%% PoolPidOrName = pid() | atom() %% Timeout = integer() %% @doc Updates the timeout for persistent connections. %% This will only affect future sockets handed to the manager. The sockets %% already managed will keep their timers. %% @end --spec update_connection_timeout(non_neg_integer()) -> ok. -update_connection_timeout(Milliseconds) -> - gen_server:cast(?MODULE, {update_timeout, Milliseconds}). +-spec update_connection_timeout(pid() | atom(), non_neg_integer()) -> ok. +update_connection_timeout(PidOrName, Milliseconds) -> + gen_server:cast(PidOrName, {update_timeout, Milliseconds}). %% @spec () -> {ok, pid()} %% @doc Starts and link to the gen server. %% This is normally called by a supervisor. %% @end --spec start_link() -> {ok, pid()} | {error, allready_started}. +-spec start_link() -> {ok, pid()} | {error, already_started}. start_link() -> - gen_server:start_link({local, ?MODULE}, ?MODULE, nil, []). + start_link([]). + +-spec start_link([{atom(), non_neg_integer()}]) -> + {ok, pid()} | {error, already_started}. +start_link(Options0) -> + Options = maybe_apply_defaults([connection_timeout, pool_size], Options0), + case proplists:get_value(name, Options) of + undefined -> + gen_server:start_link(?MODULE, Options, []); + Name -> + gen_server:start_link({local, Name}, ?MODULE, Options, []) + end. %% @hidden -spec init(any()) -> {ok, #httpc_man{}}. -init(_) -> +init(Options) -> process_flag(priority, high), - {ok, Timeout} = application:get_env(lhttpc, connection_timeout), - {ok, #httpc_man{timeout = Timeout}}. + Timeout = proplists:get_value(connection_timeout, Options), + Size = proplists:get_value(pool_size, Options), + {ok, #httpc_man{timeout = Timeout, max_pool_size = Size}}. %% @hidden -spec handle_call(any(), any(), #httpc_man{}) -> {reply, any(), #httpc_man{}}. -handle_call({socket, Pid, Host, Port, Ssl}, _, State) -> - {Reply, NewState} = find_socket({Host, Port, Ssl}, Pid, State), - {reply, Reply, NewState}; +handle_call({socket, Pid, Host, Port, Ssl}, {Pid, _Ref} = From, State) -> + #httpc_man{ + max_pool_size = MaxSize, + clients = Clients, + queues = Queues + } = State, + Dest = {Host, Port, Ssl}, + {Reply0, State2} = find_socket(Dest, Pid, State), + case Reply0 of + {ok, _Socket} -> + State3 = monitor_client(Dest, From, State2), + {reply, Reply0, State3}; + no_socket -> + case dict:size(Clients) >= MaxSize of + true -> + Queues2 = add_to_queue(Dest, From, Queues), + {noreply, State2#httpc_man{queues = Queues2}}; + false -> + {reply, no_socket, monitor_client(Dest, From, State2)} + end + end; handle_call(connection_count, _, State) -> {reply, dict:size(State#httpc_man.sockets), State}; handle_call({connection_count, Destination}, _, State) -> @@ -119,14 +156,19 @@ handle_call({connection_count, Destination}, _, State) -> error -> 0 end, {reply, Count, State}; +handle_call({done, Host, Port, Ssl, Socket}, {Pid, _} = From, State) -> + gen_server:reply(From, ok), + Dest = {Host, Port, Ssl}, + {Dest, MonRef} = dict:fetch(Pid, State#httpc_man.clients), + true = erlang:demonitor(MonRef, [flush]), + Clients2 = dict:erase(Pid, State#httpc_man.clients), + State2 = deliver_socket(Socket, Dest, State#httpc_man{clients = Clients2}), + {noreply, State2}; handle_call(_, _, State) -> {reply, {error, unknown_request}, State}. %% @hidden -spec handle_cast(any(), #httpc_man{}) -> {noreply, #httpc_man{}}. -handle_cast({done, Host, Port, Ssl, Socket}, State) -> - NewState = store_socket({Host, Port, Ssl}, Socket, State), - {noreply, NewState}; handle_cast({update_timeout, Milliseconds}, State) -> {noreply, State#httpc_man{timeout = Milliseconds}}; handle_cast(_, State) -> @@ -148,6 +190,17 @@ handle_info({tcp, Socket, _}, State) -> {noreply, remove_socket(Socket, State)}; % got garbage handle_info({ssl, Socket, _}, State) -> {noreply, remove_socket(Socket, State)}; % got garbage +handle_info({'DOWN', MonRef, process, Pid, _Reason}, State) -> + {Dest, MonRef} = dict:fetch(Pid, State#httpc_man.clients), + Clients2 = dict:erase(Pid, State#httpc_man.clients), + case queue_out(Dest, State#httpc_man.queues) of + empty -> + {noreply, State#httpc_man{clients = Clients2}}; + {ok, From, Queues2} -> + gen_server:reply(From, no_socket), + State2 = State#httpc_man{queues = Queues2, clients = Clients2}, + {noreply, monitor_client(Dest, From, State2)} + end; handle_info(_, State) -> {noreply, State}. @@ -236,3 +289,61 @@ cancel_timer(Timer, Socket) -> end; _ -> ok end. + +add_to_queue({_Host, _Port, _Ssl} = Dest, From, Queues) -> + case dict:find(Dest, Queues) of + error -> + dict:store(Dest, queue:in(From, queue:new()), Queues); + {ok, Q} -> + dict:store(Dest, queue:in(From, Q), Queues) + end. + +queue_out({_Host, _Port, _Ssl} = Dest, Queues) -> + case dict:find(Dest, Queues) of + error -> + empty; + {ok, Q} -> + {{value, From}, Q2} = queue:out(Q), + Queues2 = case queue:is_empty(Q2) of + true -> + dict:erase(Dest, Queues); + false -> + dict:store(Dest, Q2, Queues) + end, + {ok, From, Queues2} + end. + +deliver_socket(Socket, {_, _, Ssl} = Dest, State) -> + case queue_out(Dest, State#httpc_man.queues) of + empty -> + store_socket(Dest, Socket, State); + {ok, {PidWaiter, _} = FromWaiter, Queues2} -> + lhttpc_sock:setopts(Socket, [{active, false}], Ssl), + case lhttpc_sock:controlling_process(Socket, PidWaiter, Ssl) of + ok -> + gen_server:reply(FromWaiter, {ok, Socket}), + monitor_client(Dest, FromWaiter, State#httpc_man{queues = Queues2}); + {error, badarg} -> % Pid died, reuse for someone else + lhttpc_sock:setopts(Socket, [{active, once}], Ssl), + deliver_socket(Socket, Dest, State#httpc_man{queues = Queues2}); + _ -> % Something wrong with the socket; just remove it + catch lhttpc_sock:close(Socket, Ssl), + State + end + end. + +monitor_client(Dest, {Pid, _} = _From, State) -> + MonRef = erlang:monitor(process, Pid), + Clients2 = dict:store(Pid, {Dest, MonRef}, State#httpc_man.clients), + State#httpc_man{clients = Clients2}. + +maybe_apply_defaults([], Options) -> + Options; +maybe_apply_defaults([OptName | Rest], Options) -> + case proplists:is_defined(OptName, Options) of + true -> + maybe_apply_defaults(Rest, Options); + false -> + {ok, Default} = application:get_env(lhttpc, OptName), + maybe_apply_defaults(Rest, [{OptName, Default} | Options]) + end. diff --git a/src/lhttpc_sup.erl b/src/lhttpc_sup.erl index e96a717..601934c 100644 --- a/src/lhttpc_sup.erl +++ b/src/lhttpc_sup.erl @@ -51,7 +51,7 @@ start_link() -> %% @hidden -spec init(any()) -> {ok, {{atom(), integer(), integer()}, [child()]}}. init(_) -> - LHTTPCManager = {lhttpc_manager, {lhttpc_manager, start_link, []}, + LHTTPCManager = {lhttpc_manager, {lhttpc_manager, start_link, [[{name, lhttpc_manager}]]}, permanent, 10000, worker, [lhttpc_manager] }, {ok, {{one_for_one, 10, 1}, [LHTTPCManager]}}. diff --git a/test/lhttpc_lib_tests.erl b/test/lhttpc_lib_tests.erl index 493803d..bd00bd8 100644 --- a/test/lhttpc_lib_tests.erl +++ b/test/lhttpc_lib_tests.erl @@ -27,24 +27,179 @@ %%% @author Oscar Hellström -module(lhttpc_lib_tests). +-include("../src/lhttpc_types.hrl"). +-include("../include/lhttpc.hrl"). -include_lib("eunit/include/eunit.hrl"). parse_url_test_() -> [ - ?_assertEqual({"host", 80, "/", false}, - lhttpc_lib:parse_url("http://host")), - ?_assertEqual({"host", 80, "/", false}, - lhttpc_lib:parse_url("http://host/")), - ?_assertEqual({"host", 443, "/", true}, - lhttpc_lib:parse_url("https://host")), - ?_assertEqual({"host", 443, "/", true}, - lhttpc_lib:parse_url("https://host/")), - ?_assertEqual({"host", 180, "/", false}, - lhttpc_lib:parse_url("http://host:180")), - ?_assertEqual({"host", 180, "/", false}, - lhttpc_lib:parse_url("http://host:180/")), - ?_assertEqual({"host", 180, "/foo", false}, - lhttpc_lib:parse_url("http://host:180/foo")), - ?_assertEqual({"host", 180, "/foo/bar", false}, - lhttpc_lib:parse_url("http://host:180/foo/bar")) + ?_assertEqual(#lhttpc_url{ + host = "host", + port = 80, + path = "/", + is_ssl = false, + user = "", + password = "" + }, + lhttpc_lib:parse_url("http://host")), + + ?_assertEqual(#lhttpc_url{ + host = "host", + port = 80, + path = "/", + is_ssl = false, + user = "", + password = "" + }, + lhttpc_lib:parse_url("http://host/")), + + ?_assertEqual(#lhttpc_url{ + host = "host", + port = 443, + path = "/", + is_ssl = true, + user = "", + password = "" + }, + lhttpc_lib:parse_url("https://host")), + + ?_assertEqual(#lhttpc_url{ + host = "host", + port = 443, + path = "/", + is_ssl = true, + user = "", + password = "" + }, + lhttpc_lib:parse_url("https://host/")), + + ?_assertEqual(#lhttpc_url{ + host = "host", + port = 180, + path = "/", + is_ssl = false, + user = "", + password = "" + }, + lhttpc_lib:parse_url("http://host:180")), + + ?_assertEqual(#lhttpc_url{ + host = "host", + port = 180, + path = "/", + is_ssl = false, + user = "", + password = "" + }, + lhttpc_lib:parse_url("http://host:180/")), + + ?_assertEqual(#lhttpc_url{ + host = "host", + port = 180, + path = "/foo", + is_ssl = false, + user = "", + password = "" + }, + lhttpc_lib:parse_url("http://host:180/foo")), + + ?_assertEqual(#lhttpc_url{ + host = "host", + port = 180, + path = "/foo/bar", + is_ssl = false, + user = "", + password = "" + }, + lhttpc_lib:parse_url("http://host:180/foo/bar")), + + ?_assertEqual(#lhttpc_url{ + host = "host", + port = 180, + path = "/foo/bar", + is_ssl = false, + user = "joe", + password = "erlang" + }, + lhttpc_lib:parse_url("http://joe:erlang@host:180/foo/bar")), + + ?_assertEqual(#lhttpc_url{ + host = "host", + port = 180, + path = "/foo/bar", + is_ssl = false, + user = "joe", + password = "" + }, + lhttpc_lib:parse_url("http://joe@host:180/foo/bar")), + + ?_assertEqual(#lhttpc_url{ + host = "host", + port = 180, + path = "/foo/bar", + is_ssl = false, + user = "", + password = "" + }, + lhttpc_lib:parse_url("http://@host:180/foo/bar")), + + ?_assertEqual(#lhttpc_url{ + host = "host", + port = 180, + path = "/foo/bar", + is_ssl = false, + user = "joe%3Aarm", + password = "erlang" + }, + lhttpc_lib:parse_url("http://joe%3Aarm:erlang@host:180/foo/bar")), + + ?_assertEqual(#lhttpc_url{ + host = "host", + port = 180, + path = "/foo/bar", + is_ssl = false, + user = "joe%3aarm", + password = "erlang%2Fotp" + }, + lhttpc_lib:parse_url("http://joe%3aarm:erlang%2Fotp@host:180/foo/bar")), + + ?_assertEqual(#lhttpc_url{ + host = "::1", + port = 80, + path = "/foo/bar", + is_ssl = false, + user = "", + password = "" + }, + lhttpc_lib:parse_url("http://[::1]/foo/bar")), + + ?_assertEqual(#lhttpc_url{ + host = "::1", + port = 180, + path = "/foo/bar", + is_ssl = false, + user = "", + password = "" + }, + lhttpc_lib:parse_url("http://[::1]:180/foo/bar")), + + ?_assertEqual(#lhttpc_url{ + host = "::1", + port = 180, + path = "/foo/bar", + is_ssl = false, + user = "joe", + password = "erlang" + }, + lhttpc_lib:parse_url("http://joe:erlang@[::1]:180/foo/bar")), + + ?_assertEqual(#lhttpc_url{ + host = "1080:0:0:0:8:800:200c:417a", + port = 180, + path = "/foo/bar", + is_ssl = false, + user = "joe", + password = "erlang" + }, + lhttpc_lib:parse_url("http://joe:erlang@[1080:0:0:0:8:800:200C:417A]:180/foo/bar")) ]. diff --git a/test/lhttpc_manager_tests.erl b/test/lhttpc_manager_tests.erl index 67a2344..4e24afe 100644 --- a/test/lhttpc_manager_tests.erl +++ b/test/lhttpc_manager_tests.erl @@ -25,12 +25,12 @@ %%% ---------------------------------------------------------------------------- %%% @author Oscar Hellström +%%% @author Filipe David Manana -module(lhttpc_manager_tests). -include_lib("eunit/include/eunit.hrl"). -define(HOST, "www.example.com"). --define(PORT, 666). -define(SSL, false). %%% Eunit setup stuff @@ -38,6 +38,8 @@ start_app() -> application:start(public_key), ok = application:start(ssl), + _ = application:load(lhttpc), + ok = application:set_env(lhttpc, pool_size, 3), ok = application:start(lhttpc). stop_app(_) -> @@ -49,71 +51,373 @@ manager_test_() -> {setup, fun start_app/0, fun stop_app/1, [ ?_test(empty_manager()), ?_test(one_socket()), - ?_test(many_sockets()), - ?_test(closed_race_cond()) + {timeout, 60, ?_test(connection_timeout())}, + {timeout, 60, ?_test(many_sockets())}, + {timeout, 60, ?_test(closed_race_cond())} ]} }. %%% Tests empty_manager() -> - ?assertEqual(no_socket, gen_server:call(lhttpc_manager, - {socket, self(), ?HOST, ?PORT, ?SSL})). + LS = socket_server:listen(), + link(whereis(lhttpc_manager)), % want to make sure it doesn't crash + ?assertEqual(0, lhttpc_manager:connection_count(lhttpc_manager)), + + Client = spawn_client(), + ?assertEqual(ok, ping_client(Client)), + + ?assertEqual(no_socket, client_peek_socket(Client)), + ?assertEqual(0, lhttpc_manager:connection_count(lhttpc_manager)), + + ?assertEqual(ok, stop_client(Client)), + ?assertEqual(0, lhttpc_manager:connection_count(lhttpc_manager)), + + catch gen_tcp:close(LS), + unlink(whereis(lhttpc_manager)), + ok. one_socket() -> - {LS, Socket} = socket_server:open(), - gen_tcp:close(LS), % no use of this - give_away(Socket, ?HOST, ?PORT, ?SSL), - ?assertEqual({ok, Socket}, gen_server:call(lhttpc_manager, - {socket, self(), ?HOST, ?PORT, ?SSL})), - ?assertEqual(no_socket, gen_server:call(lhttpc_manager, - {socket, self(), ?HOST, ?PORT, ?SSL})), - gen_tcp:close(Socket). + LS = socket_server:listen(), + link(whereis(lhttpc_manager)), % want to make sure it doesn't crash + ?assertEqual(0, lhttpc_manager:connection_count(lhttpc_manager)), + + Client1 = spawn_client(), + ?assertEqual(ok, ping_client(Client1)), + Client2 = spawn_client(), + ?assertEqual(ok, ping_client(Client2)), + + ?assertEqual(0, lhttpc_manager:connection_count(lhttpc_manager)), + Result1 = connect_client(Client1), + ?assertMatch({ok, _}, Result1), + {ok, Socket} = Result1, + ?assertEqual(ok, ping_client(Client1)), + + ?assertEqual(0, lhttpc_manager:connection_count(lhttpc_manager)), + ?assertEqual(ok, disconnect_client(Client1)), + ?assertEqual(ok, ping_client(Client1)), + ?assertEqual(1, lhttpc_manager:connection_count(lhttpc_manager)), + + Result2 = connect_client(Client2), + ?assertEqual({ok, Socket}, Result2), + ?assertEqual(ok, ping_client(Client2)), + ?assertEqual(0, lhttpc_manager:connection_count(lhttpc_manager)), + + ?assertEqual(ok, stop_client(Client1)), + ?assertEqual(ok, stop_client(Client2)), + catch gen_tcp:close(LS), + unlink(whereis(lhttpc_manager)), + ok. + +connection_timeout() -> + LS = socket_server:listen(), + link(whereis(lhttpc_manager)), % want to make sure it doesn't crash + ok = lhttpc_manager:update_connection_timeout(lhttpc_manager, 3000), + erlang:yield(), % make sure lhttpc_manager processes the message + ?assertEqual(0, lhttpc_manager:connection_count(lhttpc_manager)), + + Client = spawn_client(), + ?assertEqual(ok, ping_client(Client)), + + Result1 = connect_client(Client), + ?assertMatch({ok, _}, Result1), + {ok, Socket} = Result1, + ?assertEqual(ok, ping_client(Client)), + ?assertEqual(0, lhttpc_manager:connection_count(lhttpc_manager)), + + ?assertEqual(ok, disconnect_client(Client)), + ?assertEqual(ok, ping_client(Client)), + ?assertEqual(1, lhttpc_manager:connection_count(lhttpc_manager)), + + % sleep a while and verify the socket was closed by lhttpc_manager + ok = timer:sleep(3100), + ?assertEqual(0, lhttpc_manager:connection_count(lhttpc_manager)), + Result2 = connect_client(Client), + ?assertMatch({ok, _}, Result2), + {ok, Socket2} = Result2, + ?assertEqual(ok, ping_client(Client)), + ?assertEqual(0, lhttpc_manager:connection_count(lhttpc_manager)), + ?assert(Socket2 =/= Socket), + + ?assertEqual(ok, disconnect_client(Client)), + ?assertEqual(ok, ping_client(Client)), + ?assertEqual(1, lhttpc_manager:connection_count(lhttpc_manager)), + + catch gen_tcp:close(LS), + ?assertEqual(ok, stop_client(Client)), + unlink(whereis(lhttpc_manager)), + ok. many_sockets() -> - {LS, Socket1} = socket_server:open(), - {ok, Port} = inet:port(LS), - Pid2 = socket_server:accept(LS), - Pid3 = socket_server:accept(LS), - Socket2 = socket_server:connect(Pid2, Port), - Socket3 = socket_server:connect(Pid3, Port), - gen_tcp:close(LS), - give_away(Socket1, ?HOST, ?PORT, ?SSL), - give_away(Socket2, ?HOST, ?PORT, ?SSL), - give_away(Socket3, ?HOST, ?PORT, ?SSL), - ?assertEqual({ok, Socket3}, gen_server:call(lhttpc_manager, - {socket, self(), ?HOST, ?PORT, ?SSL})), - ?assertEqual({ok, Socket2}, gen_server:call(lhttpc_manager, - {socket, self(), ?HOST, ?PORT, ?SSL})), - ?assertEqual({ok, Socket1}, gen_server:call(lhttpc_manager, - {socket, self(), ?HOST, ?PORT, ?SSL})), link(whereis(lhttpc_manager)), % want to make sure it doesn't crash + LS = socket_server:listen(), + Client1 = spawn_client(), + Client2 = spawn_client(), + Client3 = spawn_client(), + ?assertEqual(ok, ping_client(Client1)), + ?assertEqual(ok, ping_client(Client2)), + ?assertEqual(ok, ping_client(Client3)), + + _Acceptor1 = socket_server:accept(LS), + Result1 = connect_client(Client1), + ?assertMatch({ok, _}, Result1), + {ok, Socket1} = Result1, + ?assertEqual(ok, ping_client(Client1)), + + ?assertEqual(ok, disconnect_client(Client1)), + ?assertEqual(ok, ping_client(Client1)), + ?assertEqual(1, lhttpc_manager:connection_count(lhttpc_manager)), + + Result2 = connect_client(Client2), + ?assertMatch({ok, Socket1}, Result2), + ?assertEqual(ok, ping_client(Client2)), + ?assertEqual(0, lhttpc_manager:connection_count(lhttpc_manager)), + + ?assertEqual(ok, disconnect_client(Client2)), + ?assertEqual(ok, ping_client(Client2)), + ?assertEqual(1, lhttpc_manager:connection_count(lhttpc_manager)), + lhttpc_manager ! {tcp_closed, Socket1}, - ?assertEqual(0, lhttpc_manager:connection_count()), + _Acceptor2 = socket_server:accept(LS), + Result3 = connect_client(Client1), + ?assertMatch({ok, _}, Result3), + {ok, Socket2} = Result3, + ?assertEqual(ok, ping_client(Client1)), + ?assertNot(lists:member(Socket2, [Socket1])), + + Result4 = connect_client(Client2), + ?assertMatch({ok, _}, Result4), + {ok, Socket3} = Result4, + ?assertEqual(ok, ping_client(Client2)), + ?assertNot(lists:member(Socket3, [Socket1, Socket2])), + + Result5 = connect_client(Client3), + ?assertMatch({ok, _}, Result5), + {ok, Socket4} = Result5, + ?assertEqual(ok, ping_client(Client3)), + ?assertNot(lists:member(Socket4, [Socket1, Socket2, Socket3])), + + Client4 = spawn_client(), + ?assertEqual(ok, ping_client(Client4)), + Result6 = connect_client(Client4), + ?assertMatch(timeout, Result6), + ?assertEqual(timeout, ping_client(Client4)), + + ?assertEqual(ok, disconnect_client(Client1)), + ?assertEqual(ok, ping_client(Client1)), + % 0 because the connection should be delivered to blocked client Client4 + ?assertEqual(0, lhttpc_manager:connection_count(lhttpc_manager)), + + Result7 = get_client_socket(Client4), + ?assertMatch({ok, _}, Result7), + {ok, Socket5} = Result7, + ?assertEqual(ok, ping_client(Client4)), + ?assertEqual(Socket2, Socket5), + + % If a blocked client dies, verify that the pool doesn't + % send a socket to it. + Client5 = spawn_client(), + Client6 = spawn_client(), + ?assertEqual(ok, ping_client(Client5)), + ?assertEqual(ok, ping_client(Client6)), + ?assertEqual(timeout, connect_client(Client5)), + ?assertEqual(timeout, connect_client(Client6)), + ?assertEqual(timeout, ping_client(Client5)), + ?assertEqual(timeout, ping_client(Client6)), + + exit(Client5, kill), + + ?assertEqual(ok, disconnect_client(Client4)), + ?assertEqual(ok, ping_client(Client4)), + % 0 because the connection should be delivered to blocked client Client6 + ?assertEqual(0, lhttpc_manager:connection_count(lhttpc_manager)), + + Result8 = get_client_socket(Client6), + ?assertMatch({ok, _}, Result8), + {ok, Socket6} = Result8, + ?assertEqual(ok, ping_client(Client6)), + ?assertEqual(Socket6, Socket5), + + % If a client holding a socket dies, without returning it to the pool, + % a blocked client will be unblocked + Client7 = spawn_client(), + ?assertEqual(ok, ping_client(Client7)), + ?assertEqual(timeout, connect_client(Client7)), + ?assertEqual(timeout, ping_client(Client7)), + + exit(Client6, kill), + Result9 = get_client_socket(Client7), + ?assertMatch({ok, _}, Result9), + {ok, Socket7} = Result9, + ?assertEqual(ok, ping_client(Client7)), + ?assertNot(lists:member( + Socket7, [Socket1, Socket2, Socket3, Socket4, Socket5, Socket6])), + + ?assertEqual(0, lhttpc_manager:connection_count(lhttpc_manager)), + ?assertEqual(ok, disconnect_client(Client2)), + ?assertEqual(1, lhttpc_manager:connection_count(lhttpc_manager)), + ?assertEqual(ok, disconnect_client(Client3)), + ?assertEqual(2, lhttpc_manager:connection_count(lhttpc_manager)), + ?assertEqual(ok, disconnect_client(Client7)), + ?assertEqual(3, lhttpc_manager:connection_count(lhttpc_manager)), + + lhttpc_manager ! {tcp_closed, Socket7}, + ?assertEqual(2, lhttpc_manager:connection_count(lhttpc_manager)), + lhttpc_manager ! {tcp_closed, Socket4}, + ?assertEqual(1, lhttpc_manager:connection_count(lhttpc_manager)), + lhttpc_manager ! {tcp_closed, Socket3}, + ?assertEqual(0, lhttpc_manager:connection_count(lhttpc_manager)), + + catch gen_tcp:close(LS), + ?assertEqual(ok, stop_client(Client1)), + ?assertEqual(ok, stop_client(Client2)), + ?assertEqual(ok, stop_client(Client3)), + ?assertEqual(ok, stop_client(Client4)), + ?assertEqual(ok, stop_client(Client7)), unlink(whereis(lhttpc_manager)), - gen_tcp:close(Socket1), - gen_tcp:close(Socket2), - gen_tcp:close(Socket3). + ok. closed_race_cond() -> - {LS, Socket} = socket_server:open(), - gen_tcp:close(LS), % no use of this - give_away(Socket, ?HOST, ?PORT, ?SSL), - Pid = self(), + LS = socket_server:listen(), + link(whereis(lhttpc_manager)), % want to make sure it doesn't crash + ?assertEqual(0, lhttpc_manager:connection_count(lhttpc_manager)), + + Client = spawn_client(), + ?assertEqual(ok, ping_client(Client)), + + Result1 = connect_client(Client), + ?assertMatch({ok, _}, Result1), + {ok, Socket} = Result1, + ?assertEqual(ok, ping_client(Client)), + + ?assertEqual(0, lhttpc_manager:connection_count(lhttpc_manager)), + ?assertEqual(ok, disconnect_client(Client)), + ?assertEqual(ok, ping_client(Client)), + ?assertEqual(1, lhttpc_manager:connection_count(lhttpc_manager)), + ManagerPid = whereis(lhttpc_manager), true = erlang:suspend_process(ManagerPid), - spawn_link(fun() -> - Pid ! {result, gen_server:call(lhttpc_manager, - {socket, self(), ?HOST, ?PORT, ?SSL})} - end), + + Pid = self(), + spawn_link(fun() -> + Pid ! {result, client_peek_socket(Client)} + end), + erlang:yield(), % make sure that the spawned process has run gen_tcp:close(Socket), % a closed message should be sent to the manager true = erlang:resume_process(ManagerPid), - Result = receive {result, R} -> R end, - ?assertEqual(no_socket, Result). + + Result2 = receive + {result, R} -> R + after 5000 -> erlang:error("Timeout receiving result from child process") + end, + + ?assertMatch(no_socket, Result2), + ?assertEqual(0, lhttpc_manager:connection_count(lhttpc_manager)), + + ?assertEqual(ok, stop_client(Client)), + catch gen_tcp:close(LS), + unlink(whereis(lhttpc_manager)), + ok. %%% Helpers functions -give_away(Socket, Host, Port, Ssl) -> - gen_tcp:controlling_process(Socket, whereis(lhttpc_manager)), - gen_server:cast(lhttpc_manager, {done, Host, Port, Ssl, Socket}). +spawn_client() -> + Parent = self(), + spawn(fun() -> client_loop(Parent, nil) end). + +client_loop(Parent, Socket) -> + receive + stop -> + catch gen_tcp:close(Socket), + Parent ! stopped; + {get_socket, Ref} -> + Parent ! {socket, Ref, Socket}, + client_loop(Parent, Socket); + {ping, Ref} -> + Parent ! {pong, Ref}, + client_loop(Parent, Socket); + {peek_socket, Ref, To} -> + Args = {socket, self(), ?HOST, get_port(), ?SSL}, + Result = gen_server:call(lhttpc_manager, Args, infinity), + To ! {result, Ref, Result}, + client_loop(Parent, Socket); + {connect, Ref} -> + Args = {socket, self(), ?HOST, get_port(), ?SSL}, + NewSocket = case gen_server:call(lhttpc_manager, Args, infinity) of + no_socket -> + socket_server:connect(get_port()); + {ok, S} -> + S + end, + Parent ! {connected, Ref, NewSocket}, + client_loop(Parent, NewSocket); + {return_socket, Ref} -> + gen_tcp:controlling_process(Socket, whereis(lhttpc_manager)), + ok = gen_server:call(lhttpc_manager, {done, ?HOST, get_port(), ?SSL, Socket}), + Parent ! {returned, Ref}, + client_loop(Parent, nil) + end. + +ping_client(Client) -> + Ref = make_ref(), + Client ! {ping, Ref}, + receive + {pong, Ref} -> + ok + after 2000 -> + timeout + end. + +connect_client(Client) -> + Ref = make_ref(), + Client ! {connect, Ref}, + receive + {connected, Ref, Socket} -> + {ok, Socket} + after 2000 -> + timeout + end. + +disconnect_client(Client) -> + Ref = make_ref(), + Client ! {return_socket, Ref}, + receive + {returned, Ref} -> + ok + after 2000 -> + timeout + end. + +get_client_socket(Client) -> + Ref = make_ref(), + Client ! {get_socket, Ref}, + receive + {socket, Ref, Socket} -> + {ok, Socket} + after 2000 -> + timeout + end. + +client_peek_socket(Client) -> + Ref = make_ref(), + Client ! {peek_socket, Ref, self()}, + receive + {result, Ref, Result} -> + Result + after 5000 -> + timeout + end. + +stop_client(Client) -> + Client ! stop, + receive + stopped -> + ok + after 2000 -> + timeout + end. + +get_port() -> + {ok, P} = application:get_env(lhttpc, test_port), + P. diff --git a/test/lhttpc_tests.erl b/test/lhttpc_tests.erl index 397741b..528097f 100644 --- a/test/lhttpc_tests.erl +++ b/test/lhttpc_tests.erl @@ -28,7 +28,7 @@ -module(lhttpc_tests). -export([test_no/2]). --import(webserver, [start/2]). +-import(webserver, [start/2, start/3]). -include_lib("eunit/include/eunit.hrl"). @@ -112,7 +112,11 @@ tcp_test_() -> {inorder, {setup, fun start_app/0, fun stop_app/1, [ ?_test(simple_get()), + ?_test(simple_get_ipv6()), ?_test(empty_get()), + ?_test(basic_auth()), + ?_test(missing_basic_auth()), + ?_test(wrong_basic_auth()), ?_test(get_with_mandatory_hdrs()), ?_test(get_with_connect_options()), ?_test(no_content_length()), @@ -161,6 +165,7 @@ ssl_test_() -> {inorder, {setup, fun start_app/0, fun stop_app/1, [ ?_test(ssl_get()), + ?_test(ssl_get_ipv6()), ?_test(ssl_post()), ?_test(ssl_chunked()), ?_test(connection_count()) % just check that it's 0 (last) @@ -181,6 +186,10 @@ simple_get() -> simple(get), simple("GET"). +simple_get_ipv6() -> + simple(get, inet6), + simple("GET", inet6). + empty_get() -> Port = start(gen_tcp, [fun empty_body/5]), URL = url(Port, "/empty"), @@ -188,6 +197,33 @@ empty_get() -> ?assertEqual({200, "OK"}, status(Response)), ?assertEqual(<<>>, body(Response)). +basic_auth() -> + User = "foo", + Passwd = "bar", + Port = start(gen_tcp, [basic_auth_responder(User, Passwd)]), + URL = url(Port, "/empty", User, Passwd), + {ok, Response} = lhttpc:request(URL, "GET", [], 1000), + ?assertEqual({200, "OK"}, status(Response)), + ?assertEqual(<<"OK">>, body(Response)). + +missing_basic_auth() -> + User = "foo", + Passwd = "bar", + Port = start(gen_tcp, [basic_auth_responder(User, Passwd)]), + URL = url(Port, "/empty"), + {ok, Response} = lhttpc:request(URL, "GET", [], 1000), + ?assertEqual({401, "Unauthorized"}, status(Response)), + ?assertEqual(<<"missing_auth">>, body(Response)). + +wrong_basic_auth() -> + User = "foo", + Passwd = "bar", + Port = start(gen_tcp, [basic_auth_responder(User, Passwd)]), + URL = url(Port, "/empty", User, "wrong_password"), + {ok, Response} = lhttpc:request(URL, "GET", [], 1000), + ?assertEqual({401, "Unauthorized"}, status(Response)), + ?assertEqual(<<"wrong_auth">>, body(Response)). + get_with_mandatory_hdrs() -> Port = start(gen_tcp, [fun simple_response/5]), URL = url(Port, "/host"), @@ -382,14 +418,14 @@ request_timeout() -> connection_timeout() -> Port = start(gen_tcp, [fun simple_response/5, fun simple_response/5]), URL = url(Port, "/close_conn"), - lhttpc_manager:update_connection_timeout(50), % very short keep alive + lhttpc_manager:update_connection_timeout(lhttpc_manager, 50), % very short keep alive {ok, Response} = lhttpc:request(URL, get, [], 100), ?assertEqual({200, "OK"}, status(Response)), ?assertEqual(<>, body(Response)), timer:sleep(100), ?assertEqual(0, - lhttpc_manager:connection_count({"localhost", Port, false})), - lhttpc_manager:update_connection_timeout(300000). % set back + lhttpc_manager:connection_count(lhttpc_manager, {"localhost", Port, false})), + lhttpc_manager:update_connection_timeout(lhttpc_manager, 300000). % set back suspended_manager() -> Port = start(gen_tcp, [fun simple_response/5, fun simple_response/5]), @@ -402,7 +438,7 @@ suspended_manager() -> ?assertEqual({error, timeout}, lhttpc:request(URL, get, [], 50)), true = erlang:resume_process(Pid), ?assertEqual(1, - lhttpc_manager:connection_count({"localhost", Port, false})), + lhttpc_manager:connection_count(lhttpc_manager, {"localhost", Port, false})), {ok, SecondResponse} = lhttpc:request(URL, get, [], 50), ?assertEqual({200, "OK"}, status(SecondResponse)), ?assertEqual(<>, body(SecondResponse)). @@ -651,6 +687,13 @@ ssl_get() -> ?assertEqual({200, "OK"}, status(Response)), ?assertEqual(<>, body(Response)). +ssl_get_ipv6() -> + Port = start(ssl, [fun simple_response/5], inet6), + URL = ssl_url(inet6, Port, "/simple"), + {ok, Response} = lhttpc:request(URL, "GET", [], 1000), + ?assertEqual({200, "OK"}, status(Response)), + ?assertEqual(<>, body(Response)). + ssl_post() -> Port = start(ssl, [fun copy_body/5]), URL = ssl_url(Port, "/simple"), @@ -683,7 +726,7 @@ ssl_chunked() -> connection_count() -> timer:sleep(50), % give the TCP stack time to deliver messages - ?assertEqual(0, lhttpc_manager:connection_count()). + ?assertEqual(0, lhttpc_manager:connection_count(lhttpc_manager)). invalid_options() -> ?assertError({bad_options, [{foo, bar}, bad_option]}, @@ -717,20 +760,46 @@ read_partial_body(Pid, Size, Acc) -> end. simple(Method) -> - Port = start(gen_tcp, [fun simple_response/5]), - URL = url(Port, "/simple"), - {ok, Response} = lhttpc:request(URL, Method, [], 1000), - {StatusCode, ReasonPhrase} = status(Response), - ?assertEqual(200, StatusCode), - ?assertEqual("OK", ReasonPhrase), - ?assertEqual(<>, body(Response)). + simple(Method, inet). + +simple(Method, Family) -> + case start(gen_tcp, [fun simple_response/5], Family) of + {error, family_not_supported} when Family =:= inet6 -> + % Localhost has no IPv6 support - not a big issue. + ?debugMsg("WARNING: impossible to test IPv6 support~n"); + Port when is_number(Port) -> + URL = url(Family, Port, "/simple"), + {ok, Response} = lhttpc:request(URL, Method, [], 1000), + {StatusCode, ReasonPhrase} = status(Response), + ?assertEqual(200, StatusCode), + ?assertEqual("OK", ReasonPhrase), + ?assertEqual(<>, body(Response)) + end. url(Port, Path) -> - "http://localhost:" ++ integer_to_list(Port) ++ Path. + url(inet, Port, Path). + +url(inet, Port, Path) -> + "http://localhost:" ++ integer_to_list(Port) ++ Path; +url(inet6, Port, Path) -> + "http://[::1]:" ++ integer_to_list(Port) ++ Path. + +url(Port, Path, User, Password) -> + url(inet, Port, Path, User, Password). + +url(inet, Port, Path, User, Password) -> + "http://" ++ User ++ ":" ++ Password ++ + "@localhost:" ++ integer_to_list(Port) ++ Path; +url(inet6, Port, Path, User, Password) -> + "http://" ++ User ++ ":" ++ Password ++ + "@[::1]:" ++ integer_to_list(Port) ++ Path. ssl_url(Port, Path) -> "https://localhost:" ++ integer_to_list(Port) ++ Path. +ssl_url(inet6, Port, Path) -> + "https://[::1]:" ++ integer_to_list(Port) ++ Path. + status({Status, _, _}) -> Status. @@ -983,3 +1052,44 @@ not_modified_response(Module, Socket, _Request, _Headers, _Body) -> "Date: Tue, 15 Nov 1994 08:12:31 GMT\r\n\r\n" ] ). + +basic_auth_responder(User, Passwd) -> + fun(Module, Socket, _Request, Headers, _Body) -> + case proplists:get_value("Authorization", Headers) of + undefined -> + Module:send( + Socket, + [ + "HTTP/1.1 401 Unauthorized\r\n", + "Content-Type: text/plain\r\n", + "Content-Length: 12\r\n\r\n", + "missing_auth" + ] + ); + "Basic " ++ Auth -> + [U, P] = string:tokens( + binary_to_list(base64:decode(iolist_to_binary(Auth))), ":"), + case {U, P} of + {User, Passwd} -> + Module:send( + Socket, + [ + "HTTP/1.1 200 OK\r\n", + "Content-Type: text/plain\r\n", + "Content-Length: 2\r\n\r\n", + "OK" + ] + ); + _ -> + Module:send( + Socket, + [ + "HTTP/1.1 401 Unauthorized\r\n", + "Content-Type: text/plain\r\n", + "Content-Length: 10\r\n\r\n", + "wrong_auth" + ] + ) + end + end + end. diff --git a/test/socket_server.erl b/test/socket_server.erl index 68cb0a8..15fdcd8 100644 --- a/test/socket_server.erl +++ b/test/socket_server.erl @@ -27,33 +27,23 @@ %%% @author Oscar Hellström -module(socket_server). --export([open/0, connect/2, listen/0, accept/1]). +-export([connect/1, listen/0, accept/1]). -export([do_accept/1]). -open() -> - {LS, Port} = listen(), - Pid = accept(LS), - {ok, Port} = inet:port(LS), - Pid ! {connecting, self()}, - {ok, Socket} = gen_tcp:connect({127,0,0,1}, Port, [{active, false}]), - receive accepted -> ok end, - {LS, Socket}. -connect(Pid, Port) -> - Pid ! {connecting, self()}, - {ok, Socket} = gen_tcp:connect({127,0,0,1}, Port, [{active, false}]), - receive accepted -> ok end, +connect(Port) -> + {ok, Socket} = gen_tcp:connect({127,0,0,1}, Port, [{active, false}, binary]), Socket. listen() -> - {ok, LS} = gen_tcp:listen(0, [{active, false}, {ip, {127,0,0,1}}]), + {ok, LS} = gen_tcp:listen(0, [{active, false}, {ip, {127,0,0,1}}, binary]), {ok, Port} = inet:port(LS), - {LS, Port}. + ok = application:set_env(lhttpc, test_port, Port), + LS. accept(LS) -> spawn_link(?MODULE, do_accept, [LS]). do_accept(LS) -> {ok, S} = gen_tcp:accept(LS), - receive {connecting, Pid} -> Pid ! accepted end, {error, closed} = gen_tcp:recv(S, 0). diff --git a/test/webserver.erl b/test/webserver.erl index 70107f0..7da8cc6 100644 --- a/test/webserver.erl +++ b/test/webserver.erl @@ -30,13 +30,21 @@ %%% @end -module(webserver). --export([start/2, read_chunked/3]). +-export([start/2, start/3, read_chunked/3]). -export([accept_connection/4]). start(Module, Responders) -> - LS = listen(Module), - spawn_link(?MODULE, accept_connection, [self(), Module, LS, Responders]), - port(Module, LS). + start(Module, Responders, inet). + +start(Module, Responders, Family) -> + case get_addr("localhost", Family) of + {ok, Addr} -> + LS = listen(Module, Addr, Family), + spawn_link(?MODULE, accept_connection, [self(), Module, LS, Responders]), + port(Module, LS); + Error -> + Error + end. accept_connection(Parent, Module, ListenSocket, Responders) -> Socket = accept(Module, ListenSocket), @@ -111,27 +119,37 @@ server_loop(Module, Socket, Request, Headers, Responders) -> Module:close(Socket) end. -listen(ssl) -> +listen(ssl, Addr, Family) -> Opts = [ + Family, {packet, http}, binary, {active, false}, - {ip, {127,0,0,1}}, + {ip, Addr}, {verify,0}, {keyfile, "../test/key.pem"}, {certfile, "../test/crt.pem"} ], {ok, LS} = ssl:listen(0, Opts), LS; -listen(Module) -> +listen(Module, Addr, Family) -> {ok, LS} = Module:listen(0, [ + Family, {packet, http}, binary, {active, false}, - {ip, {127,0,0,1}} + {ip, Addr} ]), LS. +get_addr(Host, Family) -> + case inet:getaddr(Host, Family) of + {ok, Addr} -> + {ok, Addr}; + _ -> + {error, family_not_supported} + end. + accept(ssl, ListenSocket) -> {ok, Socket} = ssl:transport_accept(ListenSocket, 10000), ok = ssl:ssl_accept(Socket),