diff --git a/src/erf_router.erl b/src/erf_router.erl index cee7ee7..4f0372c 100644 --- a/src/erf_router.erl +++ b/src/erf_router.erl @@ -249,10 +249,13 @@ handle_ast(API, #{callback := Callback} = Opts) -> Request ), + Responses = maps:get(responses, Operation, undefined), + ValidateResponseAST = validate_response(Responses), + erl_syntax:clause( [ erl_syntax:match_expr( - erl_syntax:variable('Request'), + erl_syntax:variable('Request0'), erl_syntax:map_expr( none, [ @@ -286,6 +289,18 @@ handle_ast(API, #{callback := Callback} = Opts) -> erl_syntax:variable('IsValidRequest'), IsValidRequestAST ), + erl_syntax:match_expr( + erl_syntax:variable('Request'), + erl_syntax:map_expr( + erl_syntax:variable('Request0'), + [ + erl_syntax:map_field_assoc( + erl_syntax:atom('path_parameters'), + erl_syntax:variable('PathParameters') + ) + ] + ) + ), erl_syntax:case_expr( erl_syntax:variable('IsValidRequest'), [ @@ -293,7 +308,7 @@ handle_ast(API, #{callback := Callback} = Opts) -> [erl_syntax:atom(true)], none, [ - erl_syntax:application( + erl_syntax:match_expr(erl_syntax:variable('Response'), erl_syntax:application( erl_syntax:atom(Callback), erl_syntax:atom( erlang:binary_to_atom( @@ -303,36 +318,26 @@ handle_ast(API, #{callback := Callback} = Opts) -> utf8 ) ), - [ - erl_syntax:map_expr( - erl_syntax:variable('Request'), - [ - erl_syntax:map_field_assoc( - erl_syntax:atom('path_parameters'), - erl_syntax:variable( - 'PathParameters' - ) - ) - ] - ) - ] - ) + [erl_syntax:variable('Request')] + )), + ValidateResponseAST ] ), erl_syntax:clause( [ erl_syntax:tuple([ erl_syntax:atom(false), - erl_syntax:variable('_Reason') + erl_syntax:variable('Reason') ]) ], none, [ - erl_syntax:tuple( + erl_syntax:application( + erl_syntax:atom(erf_util), + erl_syntax:atom(handle_invalid_request), [ - erl_syntax:integer(400), - erl_syntax:list([]), - erl_syntax:atom(undefined) + erl_syntax:variable('Request'), + erl_syntax:variable('Reason') ] ) ] @@ -691,6 +696,33 @@ is_valid_request(RawParameters, Request) -> ] ). +-spec validate_response(Responses) -> Result when + Responses :: undefined | #{ '*' | erf_parser:status_code() := erf_parser:response() }, + Result :: erl_syntax:syntaxTree(). +validate_response(undefined) -> + erl_syntax:atom('ok'); +validate_response(#{} = Responses) -> + Validators = erl_syntax:map_expr( + none, + [response_validation_map_assoc(R) || R <- maps:to_list(Responses)] + ), + erl_syntax:application( + erl_syntax:atom(erf_util), + erl_syntax:atom(validate_response), + [erl_syntax:variable('Request'), erl_syntax:variable('Response'), Validators] + ). + +response_validation_map_assoc({Status, #{body := #{ref := ResponseBodyRef}}}) -> + StatusAST = case Status of + '*' -> erl_syntax:atom('*'); + _ when is_integer(Status) -> erl_syntax:integer(Status) + end, + ResponseBodyModule = + erlang:binary_to_atom(erf_util:to_snake_case(ResponseBodyRef)), + erl_syntax:map_field_assoc(StatusAST, erl_syntax:atom(ResponseBodyModule)). + + + -spec load_binary(ModuleName, Bin) -> Result when ModuleName :: atom(), Bin :: binary(), diff --git a/src/erf_util.erl b/src/erf_util.erl index 76c381b..e8ff88e 100644 --- a/src/erf_util.erl +++ b/src/erf_util.erl @@ -13,10 +13,14 @@ %% limitations under the License -module(erf_util). +-include_lib("kernel/include/logger.hrl"). + %%% EXTERNAL EXPORTS -export([ to_pascal_case/1, - to_snake_case/1 + to_snake_case/1, + handle_invalid_request/2, + validate_response/3 ]). %%%----------------------------------------------------------------------------- @@ -86,3 +90,26 @@ to_snake_case([_C | Rest], [$_ | _T] = Acc) -> to_snake_case(Rest, Acc); to_snake_case([_C | Rest], Acc) -> to_snake_case(Rest, [$_ | Acc]). + + +handle_invalid_request(_Request, Reason) -> + {400, [{<<"content-type">>, <<"text/plain">>}], [io_lib:print(Reason), "\n"]}. + +validate_response(Request, {Status, _, #{} = RespBody} = Response, Validators) -> + IsValid = case Validators of + #{Status := ValidatorMod} -> + ValidatorMod:is_valid(RespBody); + #{'*' := ValidatorMod} -> + ValidatorMod:is_valid(RespBody); + #{} -> + unknown_status + end, + case IsValid of + true -> + Response; + Other -> + ?LOG(error, "[erf] Bad response (~0p)~n ~120p~n for request ~120p", [Other, Response, Request]), + {500, [{<<"content-type">>, <<"text/plain">>}], <<"Internal server error">>} + end; +validate_response(_Request, {_, _, _} = Response, _) -> + Response. diff --git a/test/erf_SUITE.erl b/test/erf_SUITE.erl index d0ef7ac..7f703da 100644 --- a/test/erf_SUITE.erl +++ b/test/erf_SUITE.erl @@ -101,7 +101,7 @@ foo(_Conf) -> ), ?assertMatch( - {ok, {{"HTTP/1.1", 400, "Bad Request"}, _Result2Headers, <<>>}}, + {ok, {{"HTTP/1.1", 400, "Bad Request"}, _Result2Headers, <<_/binary>>}}, httpc:request( post, {"http://localhost:8789/1/foo", [], "application/json", <<"\"foobar\"">>}, diff --git a/test/erf_router_SUITE.erl b/test/erf_router_SUITE.erl index 840f1ac..367b501 100644 --- a/test/erf_router_SUITE.erl +++ b/test/erf_router_SUITE.erl @@ -166,7 +166,7 @@ foo(_Conf) -> meck:expect(get_foo_request_body, is_valid, fun(_Value) -> {false, reason} end), - ?assertEqual({400, [], undefined}, Mod:handle(Req)), + ?assertMatch({400, [{<<"content-type">>, _}], _}, Mod:handle(Req)), NotAllowedReq = #{ path => [<<"1">>, <<"foo">>],