diff --git a/Project.toml b/Project.toml index be404c1..d63e0e5 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,7 @@ Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b" ReadStatTables = "52522f7a-9570-4e34-8ac6-c005c74d4b84" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" diff --git a/scope.jl b/scope.jl new file mode 100644 index 0000000..a3ee34e --- /dev/null +++ b/scope.jl @@ -0,0 +1,41 @@ +import ScopedValues + +const MyValue = ScopedValues.ScopedValue{Int}() + +function step1() + return MyValue[] + 1 +end + +function step2() + return MyValue[] * 2 +end + +function step3() + return MyValue[] - 3 +end + +function process_steps(steps) + if isempty(steps) + return MyValue[] + else + current_step = first(steps) + remaining_steps = steps[2:end] + new_value = current_step() + println("After $(current_step): ", new_value) + + ScopedValues.@with MyValue => new_value begin + process_steps(remaining_steps) + end + end +end + +function process_chain(initial_value, steps) + ScopedValues.@with MyValue => initial_value begin + println("Initial: ", MyValue[]) + process_steps(steps) + end +end + +steps = [step1, step2, step3] +result = process_chain(0, steps) +println("Final result: ", result) \ No newline at end of file diff --git a/src/Kezdi.jl b/src/Kezdi.jl index 14eaf60..30d2df9 100644 --- a/src/Kezdi.jl +++ b/src/Kezdi.jl @@ -10,6 +10,7 @@ using Reexport using Logging using InteractiveUtils using ReadStatTables +import ScopedValues @reexport using FreqTables: freqtable @reexport using FixedEffectModels @@ -31,5 +32,8 @@ include("side_effects.jl") include("With.jl") @reexport using .With: @with, @with! +runtime_context = ScopedValues.ScopedValue(RuntimeContext(nothing, true)) +compile_context = ScopedValues.ScopedValue(CompileContext()) +global_runtime_context = RuntimeContext(nothing, true) end # module diff --git a/src/With.jl b/src/With.jl index 5ed420e..2d1a45f 100644 --- a/src/With.jl +++ b/src/With.jl @@ -1,30 +1,163 @@ module With -using ..Kezdi export @with, @with! +using ..Kezdi -""" - @with df begin - # do something with df +is_aside(x) = false +function is_aside(x::Expr)::Bool + if x.head == :(=) + return is_aside(x.args[2]) end + return x.head == :macrocall && Symbol(String(x.args[1])[2:end]) in Kezdi.SIDE_EFFECTS +end -The `@with` macro is a convenience macro that allows you to set the current data frame and perform operations on it in a single block. The first argument is the data frame to set as the current data frame, and the second argument is a block of code to execute. The data frame is set as the current data frame for the duration of the block, and then restored to its previous value after the block is executed. -The macro returns the value of the last expression in the block. -""" -macro with(initial_value, args...) - block = flatten_to_single_block(initial_value, args...) +function call_with_context(e::Expr, firstarg; assignment = false) + head = e.head + args = e.args + # set assignment = true and rerun with right hand side + if !assignment && head == :(=) && length(args) == 2 + if !(args[1] isa Symbol) + error("You can only use assignment syntax with a Symbol as a variable name, not $(args[1]).") + end + variable = args[1] + righthandside = call_with_context(args[2], firstarg; assignment = true) + return :($variable = $righthandside) + end + :(Kezdi.ScopedValues.@with Kezdi.compile_context => $(Kezdi.get_compile_context()) Kezdi.runtime_context => Kezdi.RuntimeContext($firstarg) $e) +end + +function rewrite(expr, replacement) + aside = is_aside(expr) + new_expr = call_with_context(expr, replacement) + replacement = gensym() + new_expr = :(local $replacement = $new_expr) + + (new_expr, replacement, aside) +end + +rewrite(l::LineNumberNode, replacement) = (l, replacement, true) + +function rewrite_with_block(firstpart, block) + pushfirst!(block.args, firstpart) rewrite_with_block(block) end """ - @with! df begin - # do something with df - end + @with(expr, exprs...) + +Rewrites a series of expressions into a with, where the result of one expression +is inserted into the next expression following certain rules. + +**Rule 1** + +Any `expr` that is a `begin ... end` block is flattened. +For example, these two pseudocodes are equivalent: + +```julia +@with a b c d e f + +@with a begin + b + c + d +end e f +``` + +**Rule 2** + +Any expression but the first (in the flattened representation) will have the preceding result +inserted as its first argument, unless at least one underscore `_` is present. +In that case, all underscores will be replaced with the preceding result. + +If the expression is a symbol, the symbol is treated equivalently to a function call. + +For example, the following code block + +```julia +@with begin + x + f() + @g() + h + @i + j(123, _) + k(_, 123, _) +end +``` + +is equivalent to -The `@with!` macro is a convenience macro that allows you to set the current data frame and perform operations on it in a single block. The first argument is the data frame to set as the current data frame, and the second argument is a block of code to execute. The data frame is set as the current data frame for the duration of the block, and then restored to its previous value after the block is executed. +```julia +begin + local temp1 = f(x) + local temp2 = @g(temp1) + local temp3 = h(temp2) + local temp4 = @i(temp3) + local temp5 = j(123, temp4) + local temp6 = k(temp5, 123, temp5) +end +``` + +**Rule 3** + +An expression that begins with `@aside` does not pass its result on to the following expression. +Instead, the result of the previous expression will be passed on. +This is meant for inspecting the state of the with. +The expression within `@aside` will not get the previous result auto-inserted, you can use +underscores to reference it. + +```julia +@with begin + [1, 2, 3] + filter(isodd, _) + @aside @info "There are \$(length(_)) elements after filtering" + sum +end +``` + +**Rule 4** + +It is allowed to start an expression with a variable assignment. +In this case, the usual insertion rules apply to the right-hand side of that assignment. +This can be used to store intermediate results. + +```julia +@with begin + [1, 2, 3] + filtered = filter(isodd, _) + sum +end + +filtered == [1, 3] +``` + +**Rule 5** + +The `@.` macro may be used with a symbol to broadcast that function over the preceding result. + +```julia +@with begin + [1, 2, 3] + @. sqrt +end +``` + +is equivalent to + +```julia +@with begin + [1, 2, 3] + sqrt.(_) +end +``` -The macro does not have a return value, it overwrites the data frame directly. """ +macro with(initial_value, args...) + block = flatten_to_single_block(initial_value, args...) + rewrite_with_block(block) +end + + macro with!(initial_value, args...) block = flatten_to_single_block(initial_value, args...) result = rewrite_with_block(block) @@ -44,6 +177,8 @@ function flatten_to_single_block(args...) end function rewrite_with_block(block) + current_context = Kezdi.get_compile_context() + local line_number = current_context.line_number block_expressions = block.args isempty(block_expressions) || (length(block_expressions) == 1 && block_expressions[] isa LineNumberNode) && @@ -51,31 +186,37 @@ function rewrite_with_block(block) reconvert_docstrings!(block_expressions) - # save current dataframe - previous_df = gensym() + local_value = gensym() + replaced_value = local_value + current_df = local_value rewritten_exprs = [] did_first = false for expr in block_expressions # could be an expression first or a LineNumberNode, so a bit convoluted - # we just do the firstvar transformation for the first non LineNumberNode + # we just do the local_context transformation for the first non LineNumberNode # we encounter if !(did_first || expr isa LineNumberNode) + expr = :(local $local_value = $expr) did_first = true - push!(rewritten_exprs, :(local $previous_df = getdf())) - push!(rewritten_exprs, :(setdf($expr))) + push!(rewritten_exprs, expr) continue end - - push!(rewritten_exprs, expr) + + if expr isa LineNumberNode + line_number = expr.line + end + + rewritten, replaced_value, aside = Kezdi.ScopedValues.@with Kezdi.compile_context => Kezdi.CompileContext(current_context.scalars, current_context.flags, true, line_number) rewrite(expr, current_df) + push!(rewritten_exprs, rewritten) + if !aside + push!(rewritten_exprs, :(local $current_df = $replaced_value)) + end end - teardown = :(x -> begin - setdf($previous_df) - x - end) - result = Expr(:block, rewritten_exprs...) + + result = Expr(:block, rewritten_exprs..., replaced_value) - :($(esc(result)) |> $(esc(teardown))) + :($(esc(result))) end # if a line in a with is a string, it can be parsed as a docstring @@ -98,4 +239,4 @@ function reconvert_docstrings!(args::Vector) args end -end +end \ No newline at end of file diff --git a/src/codegen.jl b/src/codegen.jl index f09ba83..e85b675 100644 --- a/src/codegen.jl +++ b/src/codegen.jl @@ -2,6 +2,7 @@ function generate_command(command::Command; options=[], allowed=[]) df2 = gensym() sdf = gensym() gdf = gensym() + context = gensym() setup = Expr[] teardown = Expr[] process = (x -> x) @@ -10,6 +11,9 @@ function generate_command(command::Command; options=[], allowed=[]) target_df = df2 given_options = get_top_symbol.(command.options) + current_context = Kezdi.get_compile_context() + @warn current_context + current_context.with_block && @warn "I am in a with block, line number is $(current_context.line_number)" # check for syntax if !(:ifable in options) && !isnothing(command.condition) @@ -28,8 +32,11 @@ function generate_command(command::Command; options=[], allowed=[]) (opt in allowed) || ArgumentError("Invalid option \"$opt\" for this command: @$(command.command)") |> throw end - push!(setup, :(getdf() isa AbstractDataFrame || error("Kezdi.jl commands can only operate on a global DataFrame set by setdf()"))) - push!(setup, :(local $df2 = copy(getdf()))) + push!(setup, quote + local $context = Kezdi.get_runtime_context() + $context.df isa AbstractDataFrame || error("Kezdi.jl commands can only operate on a DataFrame") + local $df2 = copy($context.df) + end) variables_condition = (:ifable in options) ? vcat(extract_variable_references(command.condition)...) : Symbol[] variables_RHS = (:variables in options) ? vcat(extract_variable_references.(command.arguments)...) : Symbol[] variables = vcat(variables_condition, variables_RHS) @@ -71,6 +78,8 @@ function generate_command(command::Command; options=[], allowed=[]) end push!(setup, quote function $tdfunction(x) + # add global dataframe save here + $context.inplace && setdf($target_df) $(Expr(:block, teardown...)) x end @@ -89,7 +98,6 @@ function get_option(command::Command, key::Symbol) end end - function get_top_symbol(expr::Any) if expr isa Expr return get_top_symbol(expr.args[1]) diff --git a/src/commands.jl b/src/commands.jl index 6eef5a2..83b5d5c 100644 --- a/src/commands.jl +++ b/src/commands.jl @@ -9,7 +9,7 @@ function rewrite(::Val{:rename}, command::Command) ArgumentError("Syntax is @rename oldname newname") |> throw else $setup - rename!($local_copy, $arguments[1] => $arguments[2]) |> $teardown |> setdf + rename!($local_copy, $arguments[1] => $arguments[2]) |> $teardown end end |> esc end @@ -26,7 +26,7 @@ function rewrite(::Val{:generate}, command::Command) $setup $local_copy[!, $target_column] .= missing $target_df[!, $target_column] .= $RHS - $local_copy |> $teardown |> setdf + $local_copy |> $teardown end end |> esc end @@ -53,7 +53,7 @@ function rewrite(::Val{:replace}, command::Command) else $target_df[!, $target_column] .= $RHS end - $local_copy |> $teardown |> setdf + $local_copy |> $teardown end end |> esc end @@ -63,7 +63,7 @@ function rewrite(::Val{:keep}, command::Command) (; local_copy, target_df, setup, teardown, arguments, options) = gc quote $setup - $target_df[!, isempty($(command.arguments)) ? eval(:(:)) : collect($command.arguments)] |> $teardown |> setdf + $target_df[!, isempty($(command.arguments)) ? eval(:(:)) : collect($command.arguments)] |> $teardown end |> esc end @@ -73,13 +73,13 @@ function rewrite(::Val{:drop}, command::Command) if isnothing(command.condition) return quote $setup - select($local_copy, Not(collect($(command.arguments)))) |> $teardown |> setdf + select($local_copy, Not(collect($(command.arguments)))) |> $teardown end |> esc end bitmask = build_bitmask(local_copy, command.condition) return quote $setup - $local_copy[.!($bitmask), :] |> $teardown |> setdf + $local_copy[.!($bitmask), :] |> $teardown end |> esc end @@ -89,7 +89,7 @@ function rewrite(::Val{:collapse}, command::Command) combine_epxression = Expr(:call, :combine, target_df, build_assignment_formula.(command.arguments)...) quote $setup - $combine_epxression |> $teardown |> setdf + $combine_epxression |> $teardown end |> esc end @@ -104,7 +104,7 @@ function rewrite(::Val{:egen}, command::Command) else $setup $transform_expression - $local_copy |> $teardown |> setdf + $local_copy |> $teardown end end |> esc end @@ -116,7 +116,7 @@ function rewrite(::Val{:sort}, command::Command) desc = :desc in get_top_symbol.(options) ? true : false quote $setup - sort($target_df, $columns, rev=$desc) |> $teardown |> setdf + sort($target_df, $columns, rev=$desc) |> $teardown end |> esc end @@ -180,7 +180,7 @@ function rewrite(::Val{:order}, command::Command) cols = pushfirst!(cols, target_cols...) end - $target_df[!,cols]|> $teardown |> setdf + $target_df[!,cols]|> $teardown end |> esc end diff --git a/src/consts.jl b/src/consts.jl index ed62404..3a1feed 100644 --- a/src/consts.jl +++ b/src/consts.jl @@ -72,27 +72,15 @@ const TYPES = ( :Vector ) -const COMMANDS = ( - :keep, - :drop, - :generate, - :replace, - :egen, - :collapse, - :tabulate, - :summarize, - :regress -) - const SIDE_EFFECTS = ( - Symbol("@tabulate"), - Symbol("@summarize"), - Symbol("@regress"), - Symbol("@list"), - Symbol("@head"), - Symbol("@tail"), - Symbol("@names"), - Symbol("@count") + :tabulate, + :summarize, + :regress, + :list, + :head, + :tail, + :names, + :count ) const DO_NOT_VECTORIZE = ( @@ -116,11 +104,11 @@ const OPTIONS = ( :variables ) +const DEFAULT_FLAGS = Set{Symbol}() + const SYNTACTIC_OPERATORS = tuple([Symbol(x) for x in split(raw"&& || += -= *= /= //= \= ^= ÷= %= <<= >>= >>>= |= &= ⊻=")]...) const OPERATORS = tuple( vcat( [Symbol(x) for x in split(raw"= += -= −= *= /= //= \= ^= ÷= %= <<= >>= >>>= |= &= ⊻= ≔ ⩴ ≕ ← → ↔ ↚ ↛ ↞ ↠ ↢ ↣ ↦ ↤ ↮ ⇎ ⇍ ⇏ ⇐ ⇒ ⇔ ⇴ ⇶ ⇷ ⇸ ⇹ ⇺ ⇻ ⇼ ⇽ ⇾ ⇿ ⟵ ⟶ ⟷ ⟹ ⟺ ⟻ ⟼ ⟽ ⟾ ⟿ ⤀ ⤁ ⤂ ⤃ ⤄ ⤅ ⤆ ⤇ ⤌ ⤍ ⤎ ⤏ ⤐ ⤑ ⤔ ⤕ ⤖ ⤗ ⤘ ⤝ ⤞ ⤟ ⤠ ⥄ ⥅ ⥆ ⥇ ⥈ ⥊ ⥋ ⥎ ⥐ ⥒ ⥓ ⥖ ⥗ ⥚ ⥛ ⥞ ⥟ ⥢ ⥤ ⥦ ⥧ ⥨ ⥩ ⥪ ⥫ ⥬ ⥭ ⥰ ⧴ ⬱ ⬰ ⬲ ⬳ ⬴ ⬵ ⬶ ⬷ ⬸ ⬹ ⬺ ⬻ ⬼ ⬽ ⬾ ⬿ ⭀ ⭁ ⭂ ⭃ ⥷ ⭄ ⥺ ⭇ ⭈ ⭉ ⭊ ⭋ ⭌ ← → ⇜ ⇝ ↜ ↝ ↩ ↪ ↫ ↬ ↼ ↽ ⇀ ⇁ ⇄ ⇆ ⇇ ⇉ ⇋ ⇌ ⇚ ⇛ ⇠ ⇢ ↷ ↶ ↺ ↻ ~ --> <-- <--> > < >= ≥ <= ≤ == === ≡ != ≠ !== ≢ ∈ ∉ ∋ ∌ ⊆ ⊈ ⊂ ⊄ ⊊ ∝ ∊ ∍ ∥ ∦ ∷ ∺ ∻ ∽ ∾ ≁ ≃ ≂ ≄ ≅ ≆ ≇ ≈ ≉ ≊ ≋ ≌ ≍ ≎ ≐ ≑ ≒ ≓ ≖ ≗ ≘ ≙ ≚ ≛ ≜ ≝ ≞ ≟ ≣ ≦ ≧ ≨ ≩ ≪ ≫ ≬ ≭ ≮ ≯ ≰ ≱ ≲ ≳ ≴ ≵ ≶ ≷ ≸ ≹ ≺ ≻ ≼ ≽ ≾ ≿ ⊀ ⊁ ⊃ ⊅ ⊇ ⊉ ⊋ ⊏ ⊐ ⊑ ⊒ ⊜ ⊩ ⊬ ⊮ ⊰ ⊱ ⊲ ⊳ ⊴ ⊵ ⊶ ⊷ ⋍ ⋐ ⋑ ⋕ ⋖ ⋗ ⋘ ⋙ ⋚ ⋛ ⋜ ⋝ ⋞ ⋟ ⋠ ⋡ ⋢ ⋣ ⋤ ⋥ ⋦ ⋧ ⋨ ⋩ ⋪ ⋫ ⋬ ⋭ ⋲ ⋳ ⋴ ⋵ ⋶ ⋷ ⋸ ⋹ ⋺ ⋻ ⋼ ⋽ ⋾ ⋿ ⟈ ⟉ ⟒ ⦷ ⧀ ⧁ ⧡ ⧣ ⧤ ⧥ ⩦ ⩧ ⩪ ⩫ ⩬ ⩭ ⩮ ⩯ ⩰ ⩱ ⩲ ⩳ ⩵ ⩶ ⩷ ⩸ ⩹ ⩺ ⩻ ⩼ ⩽ ⩾ ⩿ ⪀ ⪁ ⪂ ⪃ ⪄ ⪅ ⪆ ⪇ ⪈ ⪉ ⪊ ⪋ ⪌ ⪍ ⪎ ⪏ ⪐ ⪑ ⪒ ⪓ ⪔ ⪕ ⪖ ⪗ ⪘ ⪙ ⪚ ⪛ ⪜ ⪝ ⪞ ⪟ ⪠ ⪡ ⪢ ⪣ ⪤ ⪥ ⪦ ⪧ ⪨ ⪩ ⪪ ⪫ ⪬ ⪭ ⪮ ⪯ ⪰ ⪱ ⪲ ⪳ ⪴ ⪵ ⪶ ⪷ ⪸ ⪹ ⪺ ⪻ ⪼ ⪽ ⪾ ⪿ ⫀ ⫁ ⫂ ⫃ ⫄ ⫅ ⫆ ⫇ ⫈ ⫉ ⫊ ⫋ ⫌ ⫍ ⫎ ⫏ ⫐ ⫑ ⫒ ⫓ ⫔ ⫕ ⫖ ⫗ ⫘ ⫙ ⫷ ⫸ ⫹ ⫺ ⊢ ⊣ ⟂ ⫪ ⫫ <: >: + - − ¦ | ⊕ ⊖ ⊞ ⊟ ++ ∪ ∨ ⊔ ± ∓ ∔ ∸ ≏ ⊎ ⊻ ⊽ ⋎ ⋓ ⟇ ⧺ ⧻ ⨈ ⨢ ⨣ ⨤ ⨥ ⨦ ⨧ ⨨ ⨩ ⨪ ⨫ ⨬ ⨭ ⨮ ⨹ ⨺ ⩁ ⩂ ⩅ ⩊ ⩌ ⩏ ⩐ ⩒ ⩔ ⩖ ⩗ ⩛ ⩝ ⩡ ⩢ ⩣ * / ⌿ ÷ % & · · ⋅ ∘ × \ ∩ ∧ ⊗ ⊘ ⊙ ⊚ ⊛ ⊠ ⊡ ⊓ ∗ ∙ ∤ ⅋ ≀ ⊼ ⋄ ⋆ ⋇ ⋉ ⋊ ⋋ ⋌ ⋏ ⋒ ⟑ ⦸ ⦼ ⦾ ⦿ ⧶ ⧷ ⨇ ⨰ ⨱ ⨲ ⨳ ⨴ ⨵ ⨶ ⨷ ⨸ ⨻ ⨼ ⨽ ⩀ ⩃ ⩄ ⩋ ⩍ ⩎ ⩑ ⩓ ⩕ ⩘ ⩚ ⩜ ⩞ ⩟ ⩠ ⫛ ⊍ ▷ ⨝ ⟕ ⟖ ⟗ ⨟ // ^ ↑ ↓ ⇵ ⟰ ⟱ ⤈ ⤉ ⤊ ⤋ ⤒ ⤓ ⥉ ⥌ ⥍ ⥏ ⥑ ⥔ ⥕ ⥘ ⥙ ⥜ ⥝ ⥠ ⥡ ⥣ ⥥ ⥮ ⥯ ↑ ↓ << >> >>>")], SYNTACTIC_OPERATORS...)...) -# not really a const, but anyway -global _global_dataframe::Union{AbstractDataFrame, Nothing} = nothing diff --git a/src/functions.jl b/src/functions.jl index 2e92790..b5bfa15 100644 --- a/src/functions.jl +++ b/src/functions.jl @@ -1,18 +1,21 @@ use(fname::AbstractString) = readstat(fname) |> DataFrame |> setdf """ - getdf() -> AbstractDataFrame + get_runtime_context() -> RuntimeContext -Return the global data frame. +Return the current runtime context. This can be passed on as a ScopedValue or set as a global. """ -getdf() = _global_dataframe +get_runtime_context() = Kezdi.runtime_context[].df isa Nothing ? Kezdi.global_runtime_context : Kezdi.runtime_context[] +getdf() = get_runtime_context().df +get_compile_context() = Kezdi.compile_context[] """ - setdf(df::Union{AbstractDataFrame, Nothing}) + setdf(df::AbstractDataFrame) -Set the global data frame. +Set the data frame in the global scope. """ -setdf(df::Union{AbstractDataFrame, Nothing}) = global _global_dataframe = df +setdf(df::Union{AbstractDataFrame, Nothing}) = Kezdi.global_runtime_context = RuntimeContext(df, true) + display_and_return(x) = (display(x); x) """ diff --git a/src/structs.jl b/src/structs.jl index 27860d1..421b6c1 100644 --- a/src/structs.jl +++ b/src/structs.jl @@ -19,6 +19,20 @@ struct GeneratedCommand options::Vector{Any} end +struct CompileContext + scalars::Vector{Symbol} + flags::Set{Symbol} + with_block::Bool + line_number::Int +end +CompileContext() = CompileContext(Symbol[], DEFAULT_FLAGS, false, 0) + +struct RuntimeContext + df::Any + inplace::Bool +end +RuntimeContext(df) = RuntimeContext(df, false) + using DataFrames using Statistics using StatsBase diff --git a/test/commands.jl b/test/commands.jl index 991c3a8..329f0a1 100644 --- a/test/commands.jl +++ b/test/commands.jl @@ -704,5 +704,5 @@ end @test nrow(getdf()) == 3 @drop @if x == 1 - @test nrow(getdf()) == 2 + @test nrow(getdf()) == 1 end \ No newline at end of file