diff --git a/pyproject.toml b/pyproject.toml index 1bb3e82..34c3961 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "draw_tree" -version = "0.4.0" +version = "0.4.1" description = "Game tree drawing tool for extensive form games" readme = "README.md" requires-python = ">=3.7" diff --git a/src/draw_tree/__init__.py b/src/draw_tree/__init__.py index cd3a305..7ab6c95 100644 --- a/src/draw_tree/__init__.py +++ b/src/draw_tree/__init__.py @@ -5,7 +5,7 @@ from extensive form (.ef) files, with support for Jupyter notebooks. """ -__version__ = "0.4.0" +__version__ = "0.4.1" from .core import ( draw_tree, @@ -15,7 +15,7 @@ generate_png, ef_to_tex, latex_wrapper, - efg_dl_ef + efg_dl_ef, ) from .gambit_layout import gambit_layout_to_ef @@ -23,11 +23,11 @@ __all__ = [ "draw_tree", "generate_tikz", - "generate_tex", + "generate_tex", "generate_pdf", "generate_png", "ef_to_tex", "latex_wrapper", "efg_dl_ef", - "gambit_layout_to_ef" -] \ No newline at end of file + "gambit_layout_to_ef", +] diff --git a/src/draw_tree/core.py b/src/draw_tree/core.py index 779945c..cdcdc3b 100644 --- a/src/draw_tree/core.py +++ b/src/draw_tree/core.py @@ -4,6 +4,7 @@ This module provides functionality to generate TikZ code for game trees from extensive form (.ef) files, with support for Jupyter notebooks. """ + from __future__ import annotations import sys @@ -19,7 +20,7 @@ import pygambit from pathlib import Path -from typing import List, Optional +from typing import List, Optional from IPython.core.getipython import get_ipython from draw_tree.layout import DefaultLayout @@ -30,13 +31,19 @@ grid: bool = False maxplayer: int = 4 -payup: float = 0.1 # fraction of paydown to shift first payoff up -radius: float = 0.3 # iset radius +payup: float = 0.1 # fraction of paydown to shift first payoff up +radius: float = 0.3 # iset radius # Up to 4 players and chance (in principle more) # Default names playername: List[str] = [r"\small chance", "1", "2", "3", "4"] -playertexname: List[str] = ["playerzero", "playerone", "playertwo", "playerthree", "playerfour"] +playertexname: List[str] = [ + "playerzero", + "playerone", + "playertwo", + "playerthree", + "playerfour", +] # Player names that need to be defined in TeX playerdefined: List[bool] = [False] * (maxplayer + 1) @@ -114,10 +121,11 @@ def color_definitions() -> list[str]: "\\newcommand\\playersixcolor{magenta}", ] + def outall(stream: Optional[List[str]] = None) -> None: """ Output stream to stdout. - + Args: stream: List of strings to output. Defaults to global outstream. """ @@ -130,7 +138,7 @@ def outall(stream: Optional[List[str]] = None) -> None: def outs(s: str, stream: Optional[List[str]] = None) -> None: """ Output single string to stream. - + Args: s: String to append to stream. stream: Target stream list. Defaults to global outstream. @@ -143,7 +151,7 @@ def outs(s: str, stream: Optional[List[str]] = None) -> None: def outlist(string_list: List[str]) -> None: """ Output list of strings to global outstream. - + Args: string_list: List of strings to append to global outstream. """ @@ -154,11 +162,11 @@ def outlist(string_list: List[str]) -> None: def defout(defname: str, meaning: str) -> None: """ LaTeX command for defining something. - + Args: defname: Name of the definition. meaning: Value/meaning of the definition. - + Note: Outputs TeX definition. Consider changing to LaTeX \\newcommand*. """ @@ -168,7 +176,7 @@ def defout(defname: str, meaning: str) -> None: def newdimen(dimname: str, value: str) -> None: """ LaTeX command for creating a dimension. - + Args: dimname: Name of the dimension. value: Value of the dimension. @@ -180,7 +188,7 @@ def newdimen(dimname: str, value: str) -> None: def comment(s: str) -> None: """ Output comment if not suppressed. - + Args: s: Comment text to output. """ @@ -191,7 +199,7 @@ def comment(s: str) -> None: def error(s: str, stream: Optional[List[str]] = None) -> None: """ Output error message (errors not suppressed). - + Args: s: Error message text. stream: Target stream. Defaults to global outstream. @@ -200,23 +208,24 @@ def error(s: str, stream: Optional[List[str]] = None) -> None: stream = outstream outs("% ----- Error: " + s, stream) + def readfile(filename: str) -> List[str]: """ Read file lines, stripped of blanks at end, if non-empty, into list. - + Args: filename: Path to file to read. - + Returns: List of non-empty, stripped lines from the file. - + Raises: FileNotFoundError: If the file doesn't exist. - + Reference: http://stackoverflow.com/questions/12330522/reading-a-file-without-newlines """ - with open(filename, 'r') as file: + with open(filename, "r") as file: temp = file.read().splitlines() out = [] for line in temp: @@ -229,14 +238,14 @@ def readfile(filename: str) -> List[str]: def fformat(x: float, places: int = 3) -> str: """ Format float to specified places, remove trailing ".0". - + Args: x: Number to format. places: Number of decimal places (default: 3). - + Returns: Formatted string representation of the number. - + Examples: >>> fformat(3.14159) '3.142' @@ -258,14 +267,14 @@ def fformat(x: float, places: int = 3) -> str: def coord(x: float, y: float) -> str: """ Format coordinates as pair: 3,4 -> "(3,4)". - + Args: x: X coordinate. y: Y coordinate. - + Returns: Formatted coordinate string. - + Examples: >>> coord(1.0, 2.0) '(1,2)' @@ -276,13 +285,13 @@ def coord(x: float, y: float) -> str: def twonorm(v: List[float]) -> float: """ Calculate Euclidean length of vector. - + Args: v: Vector as list of coordinates. - + Returns: Euclidean length of the vector. - + Examples: >>> twonorm([3, 4]) 5.0 @@ -296,14 +305,14 @@ def twonorm(v: List[float]) -> float: def stretch(v: List[float], length: float = 1) -> List[float]: """ Stretch vector to desired length (must be >= 0). - + Args: v: Input vector. length: Desired length (default: 1). - + Returns: Stretched vector with specified length. - + Raises: AssertionError: If the result doesn't have the expected length. """ @@ -320,10 +329,10 @@ def stretch(v: List[float], length: float = 1) -> List[float]: def degrees(v: List[float]) -> float: """ Calculate angle of vector in degrees in (-180,180]. - + Args: v: Vector as list of coordinates. - + Returns: Angle in degrees. """ @@ -342,11 +351,11 @@ def degrees(v: List[float]) -> float: def aeq(x: float, y: float = 0) -> bool: """ Test if numbers are almost equal (or equal to zero) numerically. - + Args: x: First number. y: Second number (default: 0). - + Returns: True if numbers are approximately equal. """ @@ -356,46 +365,47 @@ def aeq(x: float, y: float = 0) -> bool: def det(a: float, b: float, c: float, d: float) -> float: """ Calculate determinant of 2x2 matrix. - + Args: a, b, c, d: Matrix elements [[a, b], [c, d]]. - + Returns: Determinant value (ad - bc). """ return a * d - b * c + def isonlineseg(a: List[float], b: List[float], c: List[float]) -> bool: """ Check if point b lies on the line segment [a,c]. - + Args: a: Starting point as [x, y] coordinates. b: Point to test as [x, y] coordinates. c: Ending point as [x, y] coordinates. - + Returns: True if point b is on the line segment from a to c, False otherwise. """ - bx=b[0]-a[0] - by=b[1]-a[1] - cx=c[0]-a[0] - cy=c[1]-a[1] + bx = b[0] - a[0] + by = b[1] - a[1] + cx = c[0] - a[0] + cy = c[1] - a[1] if aeq(bx) and aeq(by): return True # a near b - if aeq( bx*cy - by*cx ): # collinear - if aeq(cx) and aeq(cy) : # a near c but not near b + if aeq(bx * cy - by * cx): # collinear + if aeq(cx) and aeq(cy): # a near c but not near b return False - if aeq(cx): # look at y coordinate - if aeq(by,cy): - return True # c near b + if aeq(cx): # look at y coordinate + if aeq(by, cy): + return True # c near b if cy >= 0: return (by >= 0) and (by <= cy) # cy < 0 return (by <= 0) and (by >= cy) # nonzero x coordinate of c, gives info - if aeq(bx,cx): - return True # c near b + if aeq(bx, cx): + return True # c near b if cx > 0: return (bx >= 0) and (bx <= cx) # cx < 0 @@ -403,21 +413,24 @@ def isonlineseg(a: List[float], b: List[float], c: List[float]) -> bool: # not collinear return False -def makearc(a: List[float], b: List[float], c: List[float], radius: float = isetradius) -> str: + +def makearc( + a: List[float], b: List[float], c: List[float], radius: float = isetradius +) -> str: """ Create arc or point around point b in triangle a,b,c. - + Args: a: First point as [x, y] coordinates. b: Center point as [x, y] coordinates. c: Third point as [x, y] coordinates. radius: Radius for the arc. Defaults to isetradius. - + Returns: TikZ coordinate string for the arc or point. """ - s = stretch([ b[1]-a[1], a[0]-b[0] ], radius) - t = stretch([ c[1]-b[1], b[0]-c[0] ], radius) + s = stretch([b[1] - a[1], a[0] - b[0]], radius) + t = stretch([c[1] - b[1], b[0] - c[0]], radius) # print "% s,t ", s,t sangle = degrees(s) tangle = degrees(t) @@ -427,20 +440,20 @@ def makearc(a: List[float], b: List[float], c: List[float], radius: float = iset sx = b[0] + s[0] sy = b[1] + s[1] # tikz code - out = coord(sx,sy) + " arc(" - out += fformat(sangle,1) + ":" - out += fformat(tangle,1) + ":" + out = coord(sx, sy) + " arc(" + out += fformat(sangle, 1) + ":" + out += fformat(tangle, 1) + ":" out += fformat(radius) + ")" # checking if point rather than arc - # print "% tangle-sangle ", tangle-sangle - if tangle-sangle > 180.01: + # print "% tangle-sangle ", tangle-sangle + if tangle - sangle > 180.01: tx = b[0] + t[0] ty = b[1] + t[1] - if tangle-sangle > 359: # very close to straight + if tangle - sangle > 359: # very close to straight # print "% 359" - x=(sx+tx)/2 - y=(sy+ty)/2 - out = coord(x,y) + x = (sx + tx) / 2 + y = (sy + ty) / 2 + out = coord(x, y) else: ax = a[0] + s[0] ay = a[1] + s[1] @@ -448,107 +461,111 @@ def makearc(a: List[float], b: List[float], c: List[float], radius: float = iset cy = c[1] + t[1] # print "% sx,sy,tx,ty", sx,sy,tx,ty # print "% ax,ay,cx,cy", ax,ay,cx,cy - D = det (sx-ax,sy-ay,cx-tx,cy-ty) + D = det(sx - ax, sy - ay, cx - tx, cy - ty) if not aeq(D): # zero determinant - do nothing - alpha = det(cx-ax,cy-ay,cx-tx,cy-ty) / D - beta = det(sx-ax,sy-ay,cx-ax,cy-ay) / D + alpha = det(cx - ax, cy - ay, cx - tx, cy - ty) / D + beta = det(sx - ax, sy - ay, cx - ax, cy - ay) / D # print "% alpha ", alpha # print "% beta ", beta - assert (alpha<1) - assert (beta<1) - ## trying to salvage tight angles, other solution is better - # if alpha<0: - # x = ax - # y = ay - # elif beta<0: - # x = cx - # y = cy - # else : - # x = ax + (sx-ax)*alpha - # y = ay + (sy-ay)*alpha - # out = coord(x,y) - if alpha >= 0 and beta >= 0 : - x = ax + (sx-ax)*alpha - y = ay + (sy-ay)*alpha - out = coord(x,y) - return out + assert alpha < 1 + assert beta < 1 + ## trying to salvage tight angles, other solution is better + # if alpha<0: + # x = ax + # y = ay + # elif beta<0: + # x = cx + # y = cy + # else : + # x = ax + (sx-ax)*alpha + # y = ay + (sy-ay)*alpha + # out = coord(x,y) + if alpha >= 0 and beta >= 0: + x = ax + (sx - ax) * alpha + y = ay + (sy - ay) * alpha + out = coord(x, y) + return out + def arcseq(nodes: List[List[float]], radius: float = isetradius) -> List[str]: """ Create a list of TikZ drawing commands around a list of coordinate pairs. - + Creates a sequence of arcs around the given nodes, removing collinear points and handling singleton information sets appropriately. - + Args: nodes: List of coordinate pairs [x,y]. radius: Radius for the arcs. Defaults to isetradius. - + Returns: List of TikZ command strings (without "draw" and ";" wrapper). - """ - nodes = nodes[:] # protect nodes parameter, now a local variable + """ + nodes = nodes[:] # protect nodes parameter, now a local variable if len(nodes) == 0: return [""] - if len(nodes) == 1: # singleton info set + if len(nodes) == 1: # singleton info set x = nodes[0][0] y = nodes[0][1] # circle only? - if aeq(xsingleiset) and aeq(ysingleiset): # no offset - # tikz code - s = coord(x,y) + " circle [radius=" + if aeq(xsingleiset) and aeq(ysingleiset): # no offset + # tikz code + s = coord(x, y) + " circle [radius=" s += fformat(radius) + "cm]" return [s] # else extend with extra point - else: - nodes.append([x+xsingleiset,y+ysingleiset]) + else: + nodes.append([x + xsingleiset, y + ysingleiset]) # now at least length 2 # successively remove points on same line segment a = nodes.pop(0) b = nodes.pop(0) newnodes = [a] - while (nodes): + while nodes: c = nodes.pop(0) - if not isonlineseg(a,b,c): + if not isonlineseg(a, b, c): newnodes.append(b) a = b - b=c + b = c newnodes.append(b) - tour = newnodes[1:2]+newnodes[:-1]+newnodes[::-1] + tour = newnodes[1:2] + newnodes[:-1] + newnodes[::-1] out = [] - for i in range(1, len(tour)-1): - out.append(makearc(tour[i-1],tour[i],tour[i+1],radius)) - return out + for i in range(1, len(tour) - 1): + out.append(makearc(tour[i - 1], tour[i], tour[i + 1], radius)) + return out + def iset(nodes: List[List[float]], radius: float = isetradius) -> str: """ Create complete TikZ drawing commands for an information set. - + Args: nodes: List of coordinate pairs [x,y]. radius: Radius for the arcs. Defaults to isetradius. - + Returns: Complete TikZ draw command string with semicolon. - """ - arcs = arcseq(nodes,radius) - # tikz code + """ + arcs = arcseq(nodes, radius) + # tikz code return "\\draw [" + isetparams + "] " + "\n -- ".join(arcs) + " -- cycle;" + ######################## handling players + def player(words: List[str]) -> tuple[int, int]: """ Parse 'player' command and handle player definitions. - + Processes player number and optional name, writing out player definition if the player is named or used for the first time. - + Args: words: List of command words starting with 'player'. - + Returns: - Tuple of (player_number, advance_count) where advance_count is + Tuple of (player_number, advance_count) where advance_count is the number of words consumed from the input. """ p = -1 # illegal player @@ -560,25 +577,26 @@ def player(words: List[str]) -> tuple[int, int]: error("need player number after 'player'") return p, advance if x < 0 or x > maxplayer: - error("need player number in 0.."+str(maxplayer)+" after 'player'") - advance = 2 # allow continued processing + error("need player number in 0.." + str(maxplayer) + " after 'player'") + advance = 2 # allow continued processing return p, advance p = x if len(words) > 2: if words[2] == "name": - if len(words) == 3: # nothing there + if len(words) == 3: # nothing there error("player name needed after 'name'") return p, advance - playername[p] = words[3] # got new player name + playername[p] = words[3] # got new player name playerdefined[p] = False advance = 4 else: - advance = 2 # only "player p" parsed + advance = 2 # only "player p" parsed if not playerdefined[p]: defout(playertexname[p], playername[p]) playerdefined[p] = True return p, advance + ######################## handling nodes # each node is itself a dict, with the fields @@ -587,23 +605,24 @@ def player(words: List[str]) -> tuple[int, int]: nodes = {} xshifts = {} + def splitnumtext(s: str) -> tuple[float, str]: """ Split a string into numeric prefix and text remainder. - + Extracts a leading number (including decimal) from a string and returns both the number and the remaining text. - + Args: s: Input string to parse. - + Returns: - Tuple of (number, remainder_text). If no number is found, + Tuple of (number, remainder_text). If no number is found, returns (1, original_string). - + Examples: "2.3abc" -> (2.3, "abc") - ".1b" -> (0.1, "b") + ".1b" -> (0.1, "b") "a" -> (1, "a") """ nodotyet = True @@ -628,16 +647,17 @@ def splitnumtext(s: str) -> tuple[float, str]: # print s, splitnumtext(s) # quit() + def xshift(words: List[str]) -> tuple[float, float, int]: """ Parse 'xshift' command to determine horizontal positioning. - + Handles xshift assignments and lookups, including named xshift variables and coefficient multipliers. - + Args: words: List of command words starting with 'xshift'. - + Returns: Tuple of (x_shift, factor, advance_count) where: - x_shift: The calculated horizontal shift value @@ -662,31 +682,30 @@ def xshift(words: List[str]) -> tuple[float, float, int]: try: num = float(a[1]) except ValueError: - error("assigment '"+ a[1] + "' must be a number") + error("assigment '" + a[1] + "' must be a number") return xs, 1, advance coeff, xsname = splitnumtext(a[0]) if xsname in xshifts: - comment("Warning: xshift '" + xsname + \ - "' re-defined to "+str(num)) + comment("Warning: xshift '" + xsname + "' re-defined to " + str(num)) xshifts[xsname] = num num *= coeff else: coeff, xsname = splitnumtext(a[0]) - if xsname: # uses a name + if xsname: # uses a name if xsname not in xshifts: error("xshift '" + xsname + "' undefined") return xs, 1, advance num = coeff * xshifts[xsname] else: num = coeff - coeff = 1 # no use of factor without label - if aeq(num): # nearly zero + coeff = 1 # no use of factor without label + if aeq(num): # nearly zero xs = 0 - if aeq(coeff): # coefficient nearly zero + if aeq(coeff): # coefficient nearly zero factor = 1 else: factor = coeff - else: # num nonzero and therefore coeff nonzero + else: # num nonzero and therefore coeff nonzero factor = coeff if neg: xs = -num @@ -703,13 +722,14 @@ def xshift(words: List[str]) -> tuple[float, float, int]: # print outstream # quit() + def fromnode(words: List[str]) -> tuple[str, int]: """ Parse 'from' command to identify parent node. - + Args: words: List of command words starting with 'from'. - + Returns: Tuple of (parent_node_id, advance_count) where parent_node_id is the cleaned node identifier and advance_count is words consumed. @@ -722,22 +742,23 @@ def fromnode(words: List[str]) -> tuple[str, int]: return fromn, advance s = cleannodeid(words[1]) if s not in nodes: - error("node "+s+" after 'from' is not defined") + error("node " + s + " after 'from' is not defined") else: fromn = s advance = 2 return fromn, advance + def move(words: List[str]) -> tuple[str, str, float, int]: """ Parse 'move' command to extract move name and positioning. - + Handles move syntax like "move:Left:0.3" where the colon-separated parts specify positioning and convexity parameters. - + Args: words: List of command words starting with 'move'. - + Returns: Tuple of (move_name, move_position, convex_value, advance_count). """ @@ -748,7 +769,7 @@ def move(words: List[str]) -> tuple[str, str, float, int]: convex = -1 a = words[0].split(":") if len(a) > 1: - movpos = (a[1]+" ")[0].lower() # first character only + movpos = (a[1] + " ")[0].lower() # first character only if len(a) > 2: try: num = float(a[2]) @@ -771,13 +792,14 @@ def move(words: List[str]) -> tuple[str, str, float, int]: # outall() # quit ("done testing.") + def arrow(words: List[str]) -> tuple[float, str, int]: """ Parse 'arrow' command to extract arrow positioning and color. - + Args: words: List of command words starting with 'arrow'. - + Returns: Tuple of (arrow_position, arrow_color, advance_count). """ @@ -799,28 +821,29 @@ def arrow(words: List[str]) -> tuple[float, str, int]: error("Arrow position in [0,1] required, using 0.5") return arrowpos, arrowcolor, advance + def payoffs(words: List[str]) -> List[str]: """ Parse 'payoffs' command to generate TikZ payoff display code. - + Args: words: List of command words starting with 'payoffs'. - + Returns: List of TikZ node commands for displaying payoffs. """ assert words[0] == "payoffs" maxp = len(words) - if len(words) > maxplayer+1: - error("too many payoffs, discard "+str(words[maxplayer+1:])) - maxp = maxplayer+1 + if len(words) > maxplayer + 1: + error("too many payoffs, discard " + str(words[maxplayer + 1 :])) + maxp = maxplayer + 1 paylist = [] for i in range(1, maxp): # tikz code t = " node[below,yshift=" - t += fformat(payup-(i-1)) + paydown + t += fformat(payup - (i - 1)) + paydown t += "] {$" + words[i] - if words[i][0] == "-": # negative payoff + if words[i][0] == "-": # negative payoff t += "{\\phantom-}" t += "$\\strut}" paylist.append(t) @@ -833,17 +856,18 @@ def payoffs(words: List[str]) -> List[str]: # print s # quit() + def drawnode(v: List[float], player: int = 1, color_scheme: str = "default") -> str: """ Generate TikZ code to draw a game tree node. - + Creates either a square (for chance/player 0) or circle (for other players). - + Args: v: Node position as [x, y] coordinates. player: Player number (0 for chance node, >0 for player node). color_scheme: Color scheme for player nodes. - + Returns: TikZ node command string. """ @@ -867,6 +891,7 @@ def drawnode(v: List[float], player: int = 1, color_scheme: str = "default") -> outs(out) return out + def drawnodes(color_scheme: str = "default") -> None: """ Draw all inner (non-leaf) nodes in the game tree. @@ -886,32 +911,34 @@ def drawnodes(color_scheme: str = "default") -> None: p = nodes[nodeid]["player"] drawnode(v, p, color_scheme) + def setnodeid(lev: float, s: str) -> str: """ Create node identifier from level and name. - + Args: lev: Level number (typically a float). s: Name string for the node. - + Returns: Formatted node identifier string "level,name". """ - return fformat(lev)+","+s + return fformat(lev) + "," + s + def cleannodeid(ns: str) -> str: """ Standardize node id from "level,name" format. - + Args: ns: Node string in "level,name" format. - + Returns: Standardized node identifier. """ a = ns.split(",") if len(a) < 2: - error("missing comma in '"+ns+"', using empty node id") + error("missing comma in '" + ns + "', using empty node id") s = "" else: s = a[1] @@ -930,6 +957,7 @@ def cleannodeid(ns: str) -> str: # print outstream # quit() + # handle "level" keyword; # commands: "node" node , then in any order # "xshift" [-][2][[a=]1.5|a] (2= multiple, a= xshift name, 1.5 = dimen) @@ -938,6 +966,7 @@ def cleannodeid(ns: str) -> str: # "payoffs" list of payoffs, comes last # "inner" boolean: inner node, draw disk/square + def parse_isets_first(lines: List[str]) -> None: """ Pre-parse all iset commands to build node-to-player mappings. @@ -976,6 +1005,7 @@ def parse_isets_first(lines: List[str]) -> None: for nodeid in nodes_in_iset: node_to_iset_player[nodeid] = p + def generate_legend( player_list: List[int], color_scheme: str = "gambit", scale_factor: float = 1.0 ) -> str: @@ -1036,11 +1066,12 @@ def generate_legend( return legend_code + def level( - words: List[str], - color_scheme: str = "default", - action_label_position: float = 0.5, - ) -> None: + words: List[str], + color_scheme: str = "default", + action_label_position: float = 0.5, +) -> None: """ Process a complete level command to create a game tree node. @@ -1106,7 +1137,7 @@ def level( else: # unknown keyword error("unknown keyword " + words[count]) count += 1 - + # If move contains a float, apply fformat if "~" in mov: movlist = mov.split("~") @@ -1182,9 +1213,7 @@ def level( outs(" -- " + coord(xfrom, yfrom) + ";") # annotate moves above if convex < 0: - convex = ( - action_label_position / factor - ) + convex = action_label_position / factor xmove = xx * convex + xfrom * (1 - convex) ymove = yy * convex + yfrom * (1 - convex) s = "\\draw " + coord(xmove, ymove) @@ -1225,8 +1254,10 @@ def level( outs(" ;") return + ######################## isets + def isetgen(words: List[str], color_scheme: str = "default") -> None: """ Process 'iset' command to generate information set visualization. @@ -1311,18 +1342,22 @@ def isetgen(words: List[str], color_scheme: str = "default") -> None: outs(s) return + ########### command-line arguments -def commandline(argv: List[str]) -> tuple[str, bool, bool, bool, Optional[str], Optional[int]]: + +def commandline( + argv: List[str], +) -> tuple[str, bool, bool, bool, Optional[str], Optional[int]]: """ Process command-line arguments to set global configuration. - + Sets global variables for ef_file, scale, and grid based on command-line arguments. Also detects if PDF or PNG output is requested. - + Args: argv: List of command-line arguments (including script name). - + Returns: Tuple of (output_mode, pdf_requested, png_requested, tex_requested, output_file, dpi) where: - output_mode: 'tikz', 'pdf', 'png', or 'tex' @@ -1333,26 +1368,31 @@ def commandline(argv: List[str]) -> tuple[str, bool, bool, bool, Optional[str], - dpi: DPI setting for PNG output (None if not specified) """ global grid - global scale + global scale global ef_file - + pdf_requested = False png_requested = False tex_requested = False output_file = None dpi = None - + for arg in argv[1:]: if arg[:5] == "scale": a = arg.split("=") - try: + try: num = float(a[1]) if num >= 0.01 and num <= 100: scale = num - else: - outs("% Command-line argument 'scale=x' needs x in 0.01 .. 100", stream0) + else: + outs( + "% Command-line argument 'scale=x' needs x in 0.01 .. 100", + stream0, + ) except Exception: - outs("% Command-line argument 'scale=x' needs x in 0.01 .. 100", stream0) + outs( + "% Command-line argument 'scale=x' needs x in 0.01 .. 100", stream0 + ) elif arg == "grid": grid = True elif arg == "--pdf": @@ -1363,27 +1403,30 @@ def commandline(argv: List[str]) -> tuple[str, bool, bool, bool, Optional[str], tex_requested = True elif arg.startswith("--output="): output_file = arg[9:] # Remove "--output=" prefix - if output_file.endswith('.pdf'): + if output_file.endswith(".pdf"): pdf_requested = True - elif output_file.endswith('.png'): + elif output_file.endswith(".png"): png_requested = True - elif output_file.endswith('.tex'): + elif output_file.endswith(".tex"): tex_requested = True elif arg.startswith("--dpi="): try: dpi = int(arg[6:]) # Remove "--dpi=" prefix if dpi < 72 or dpi > 2400: - print("Warning: DPI should be between 72 and 2400, using default 300", file=sys.stderr) + print( + "Warning: DPI should be between 72 and 2400, using default 300", + file=sys.stderr, + ) dpi = 300 except ValueError: print("Warning: Invalid DPI value, using default 300", file=sys.stderr) dpi = 300 - elif arg.endswith('.ef'): + elif arg.endswith(".ef"): ef_file = arg else: # For backward compatibility, treat unknown args as filenames ef_file = arg - + # Determine output mode if png_requested: output_mode = "png" @@ -1393,9 +1436,10 @@ def commandline(argv: List[str]) -> tuple[str, bool, bool, bool, Optional[str], output_mode = "tex" else: output_mode = "tikz" - + return (output_mode, pdf_requested, png_requested, tex_requested, output_file, dpi) + def ef_to_tex( ef_file: str, scale_factor: float = 1.0, @@ -1420,7 +1464,7 @@ def ef_to_tex( Complete TikZ code as a string. """ # Scale adjustment - scale_factor = scale_factor*0.8 + scale_factor = scale_factor * 0.8 global scale, grid, node_to_iset_player @@ -1515,6 +1559,7 @@ def ef_to_tex( scale = original_scale grid = original_grid + def generate_tikz( game: str | "pygambit.gambit.Game", save_to: Optional[str] = None, @@ -1554,7 +1599,7 @@ def generate_tikz( # it successfully writes the .ef file. ef_file = game if isinstance(game, str): - if game.lower().endswith('.efg'): + if game.lower().endswith(".efg"): try: ef_file = efg_dl_ef(game) except Exception: @@ -1562,20 +1607,23 @@ def generate_tikz( pass else: from .gambit_layout import gambit_layout_to_ef + # Generate the ef, use normalised spacing options ef_file = gambit_layout_to_ef( game, save_to=save_to, - level_multiplier=level_scaling*4, - sublevel_multiplier=sublevel_scaling*2 , - xshift_multiplier=width_scaling*2, + level_multiplier=level_scaling * 4, + sublevel_multiplier=sublevel_scaling * 2, + xshift_multiplier=width_scaling * 2, hide_action_labels=hide_action_labels, shared_terminal_depth=shared_terminal_depth, ) # Step 1: Generate the tikzpicture content using ef_to_tex logic - tikz_picture_content = ef_to_tex(ef_file, scale_factor, show_grid, color_scheme, action_label_position) - + tikz_picture_content = ef_to_tex( + ef_file, scale_factor, show_grid, color_scheme, action_label_position + ) + # Step 2: Define built-in macro definitions (from macros-drawtree.tex) macro_definitions = [ "\\newdimen\\ndiam", @@ -1680,10 +1728,10 @@ def draw_tree( # Execute cell magic or return TikZ ip = get_ipython() if ip: - em = getattr(ip, 'extension_manager', None) - loaded = getattr(em, 'loaded', None) + em = getattr(ip, "extension_manager", None) + loaded = getattr(em, "loaded", None) try: - jpt_loaded = 'jupyter_tikz' in loaded # type: ignore + jpt_loaded = "jupyter_tikz" in loaded # type: ignore except Exception: jpt_loaded = False if not jpt_loaded: @@ -1696,7 +1744,7 @@ def draw_tree( def latex_wrapper(tikz_code: str) -> str: """ Wrap TikZ code in a complete LaTeX document. - + Args: tikz_code: The TikZ code to embed in the document. Returns: @@ -1762,14 +1810,14 @@ def generate_tex( if isinstance(game, str): game_path = Path(game) else: - game_path = Path(game.title + '.ef') - output_tex = game_path.with_suffix('.tex').name + game_path = Path(game.title + ".ef") + output_tex = game_path.with_suffix(".tex").name else: - if not save_to.endswith('.tex'): - output_tex = save_to + '.tex' + if not save_to.endswith(".tex"): + output_tex = save_to + ".tex" else: output_tex = save_to - + # If game is an EFG file, convert it first if isinstance(game, str) and game.lower().endswith(".efg"): try: @@ -1792,31 +1840,31 @@ def generate_tex( edge_thickness=edge_thickness, action_label_position=action_label_position, ) - + # Wrap in complete LaTeX document latex_document = latex_wrapper(tikz_code) - + # Write to file - with open(output_tex, 'w') as f: + with open(output_tex, "w") as f: f.write(latex_document) - + return str(Path(output_tex).absolute()) def generate_pdf( - game: str | "pygambit.gambit.Game", - save_to: Optional[str] = None, - scale_factor: float = 1.0, - level_scaling: int = 1, - sublevel_scaling: int = 1, - width_scaling: int = 1, - hide_action_labels: bool = False, - shared_terminal_depth: bool = False, - show_grid: bool = False, - color_scheme: str = "default", - edge_thickness: float = 1.0, - action_label_position: float = 0.5, - ) -> str: + game: str | "pygambit.gambit.Game", + save_to: Optional[str] = None, + scale_factor: float = 1.0, + level_scaling: int = 1, + sublevel_scaling: int = 1, + width_scaling: int = 1, + hide_action_labels: bool = False, + shared_terminal_depth: bool = False, + show_grid: bool = False, + color_scheme: str = "default", + edge_thickness: float = 1.0, + action_label_position: float = 0.5, +) -> str: """ Generate a PDF directly from an extensive form (.ef) file. @@ -1863,7 +1911,7 @@ def generate_pdf( game = efg_dl_ef(game) except Exception: pass - + # Generate TikZ content using generate_tikz tikz_code = generate_tikz( game, @@ -1879,47 +1927,58 @@ def generate_pdf( edge_thickness=edge_thickness, action_label_position=action_label_position, ) - + # Create LaTeX wrapper document latex_document = latex_wrapper(tikz_code) - + # Use temporary directory for LaTeX compilation with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) - + # Write LaTeX file tex_file = temp_path / "output.tex" - with open(tex_file, 'w', encoding='utf-8') as f: + with open(tex_file, "w", encoding="utf-8") as f: f.write(latex_document) - + # Compile with pdflatex try: - subprocess.run([ - 'pdflatex', - '-interaction=nonstopmode', - '-output-directory', str(temp_path), - str(tex_file) - ], capture_output=True, text=True, check=True) - + subprocess.run( + [ + "pdflatex", + "-interaction=nonstopmode", + "-output-directory", + str(temp_path), + str(tex_file), + ], + capture_output=True, + text=True, + check=True, + ) + # Move the generated PDF to the desired location generated_pdf = temp_path / "output.pdf" final_pdf_path = Path(output_pdf) - + if generated_pdf.exists(): # Copy to final destination import shutil + shutil.copy2(generated_pdf, final_pdf_path) return str(final_pdf_path.absolute()) else: raise RuntimeError("PDF was not generated successfully") - + except subprocess.CalledProcessError as e: error_msg = f"LaTeX compilation failed:\n{e.stderr}" if "command not found" in e.stderr or "No such file" in str(e): - error_msg += "\n\nMake sure pdflatex is installed and available in your PATH." + error_msg += ( + "\n\nMake sure pdflatex is installed and available in your PATH." + ) raise RuntimeError(error_msg) except FileNotFoundError: - raise RuntimeError("pdflatex not found. Please install a LaTeX distribution (e.g., TeX Live, MiKTeX).") + raise RuntimeError( + "pdflatex not found. Please install a LaTeX distribution (e.g., TeX Live, MiKTeX)." + ) def generate_png( @@ -1983,16 +2042,16 @@ def generate_png( game = efg_dl_ef(game) except Exception: pass - + # Step 1: Generate PDF first with tempfile.TemporaryDirectory() as temp_dir: - temp_pdf = Path(temp_dir) / "temp_output.pdf" - + temp_pdf = str(Path(temp_dir) / "temp_output.pdf") + try: # Generate PDF using existing function generate_pdf( game=game, - save_to=save_to, + save_to=temp_pdf, scale_factor=scale_factor, level_scaling=level_scaling, sublevel_scaling=sublevel_scaling, @@ -2004,67 +2063,80 @@ def generate_png( edge_thickness=edge_thickness, action_label_position=action_label_position, ) - + # Step 2: Convert PDF to PNG final_png_path = Path(output_png) - + # Try different conversion methods in order of preference conversion_success = False - + # Method 1: Try ImageMagick convert try: - subprocess.run([ - 'convert', - '-density', str(dpi), - '-quality', '100', - str(temp_pdf), - str(final_png_path) - ], capture_output=True, text=True, check=True) + subprocess.run( + [ + "convert", + "-density", + str(dpi), + "-quality", + "100", + str(temp_pdf), + str(final_png_path), + ], + capture_output=True, + text=True, + check=True, + ) conversion_success = True except (subprocess.CalledProcessError, FileNotFoundError): pass - + # Method 2: Try Ghostscript if ImageMagick failed if not conversion_success: try: - subprocess.run([ - 'gs', - '-dNOPAUSE', - '-dBATCH', - '-sDEVICE=png16m', - f'-r{dpi}', - f'-sOutputFile={final_png_path}', - str(temp_pdf) - ], capture_output=True, text=True, check=True) + subprocess.run( + [ + "gs", + "-dNOPAUSE", + "-dBATCH", + "-sDEVICE=png16m", + f"-r{dpi}", + f"-sOutputFile={final_png_path}", + str(temp_pdf), + ], + capture_output=True, + text=True, + check=True, + ) conversion_success = True except (subprocess.CalledProcessError, FileNotFoundError): pass - + # Method 3: Try pdftoppm + convert if available if not conversion_success: try: temp_ppm = Path(temp_dir) / "temp_output" # Convert PDF to PPM first - subprocess.run([ - 'pdftoppm', - '-r', str(dpi), - str(temp_pdf), - str(temp_ppm) - ], capture_output=True, text=True, check=True) - + subprocess.run( + ["pdftoppm", "-r", str(dpi), str(temp_pdf), str(temp_ppm)], + capture_output=True, + text=True, + check=True, + ) + # Find the generated PPM file (pdftoppm adds -1.ppm suffix) ppm_file = Path(temp_dir) / f"{temp_ppm.name}-1.ppm" if ppm_file.exists(): # Convert PPM to PNG - subprocess.run([ - 'convert', - str(ppm_file), - str(final_png_path) - ], capture_output=True, text=True, check=True) + subprocess.run( + ["convert", str(ppm_file), str(final_png_path)], + capture_output=True, + text=True, + check=True, + ) conversion_success = True except (subprocess.CalledProcessError, FileNotFoundError): pass - + if not conversion_success: raise RuntimeError( "PNG conversion failed. Please install one of the following:\n" @@ -2076,12 +2148,12 @@ def generate_png( " Ubuntu: sudo apt-get install imagemagick ghostscript poppler-utils\n" " Windows: Install ImageMagick or Ghostscript from their websites" ) - + if final_png_path.exists(): return str(final_png_path.absolute()) else: raise RuntimeError("PNG was not generated successfully") - + except FileNotFoundError: # Re-raise file not found errors directly raise @@ -2110,7 +2182,6 @@ def efg_dl_ef(efg_file: str) -> str: lines = readfile(efg_file) - # Extract players from header if present. header = "\n".join(lines[:5]) m_players = re.search(r"\{\s*([\s\S]*?)\s*\}", header) @@ -2122,7 +2193,7 @@ def efg_dl_ef(efg_file: str) -> str: descriptors = [] for raw in lines: line = raw.strip() - if not line or line.startswith('%') or line.startswith('#'): + if not line or line.startswith("%") or line.startswith("#"): continue tokens = line.split() if not tokens: @@ -2134,13 +2205,13 @@ def efg_dl_ef(efg_file: str) -> str: probs = [] payoffs = [] player = None - if kind == 'c' or kind == 'p': + if kind == "c" or kind == "p": if brace: moves = re.findall(r'"([^"\\]*)"', brace.group(1)) # also extract probabilities (numbers) in brace - probs = re.findall(r'([0-9]+\/[0-9]+|[0-9]*\.?[0-9]+)', brace.group(1)) + probs = re.findall(r"([0-9]+\/[0-9]+|[0-9]*\.?[0-9]+)", brace.group(1)) # attempt to find player id for 'p' lines - if kind == 'p': + if kind == "p": # find first integer token after type nums = [t for t in tokens[1:] if t.isdigit()] if len(nums) >= 1: @@ -2151,17 +2222,17 @@ def efg_dl_ef(efg_file: str) -> str: iset_id = int(nums[1]) else: iset_id = None - elif kind == 't': + elif kind == "t": # terminal: extract payoffs (allow integers and decimals) if brace: # Match floats like 12.80, .80, -1.5 or integers like 3 - pay_tokens = re.findall(r'(-?\d*\.\d+|-?\d+)', brace.group(1)) + pay_tokens = re.findall(r"(-?\d*\.\d+|-?\d+)", brace.group(1)) payoffs = [] for tok in pay_tokens: # If token contains a decimal point treat as float and # format with two decimal places (keeps trailing zeros), # otherwise treat as integer. - if '.' in tok: + if "." in tok: try: v = float(tok) payoffs.append("{:.2f}".format(v)) @@ -2173,18 +2244,20 @@ def efg_dl_ef(efg_file: str) -> str: payoffs.append(str(int(tok))) except Exception: payoffs.append(tok) - descriptors.append({ - 'kind': kind, - 'player': player, - 'moves': moves, - 'probs': probs, - 'payoffs': payoffs, - 'iset_id': locals().get('iset_id', None), - 'raw': line, - }) + descriptors.append( + { + "kind": kind, + "player": player, + "moves": moves, + "probs": probs, + "payoffs": payoffs, + "iset_id": locals().get("iset_id", None), + "raw": line, + } + ) # Filter descriptors to only the game records (c, p, t) - descriptors = [d for d in descriptors if d['kind'] in ('c', 'p', 't')] + descriptors = [d for d in descriptors if d["kind"] in ("c", "p", "t")] # Layout/emission: delegate to DefaultLayout class for clarity/testability layout = DefaultLayout(descriptors, player_names) @@ -2192,9 +2265,9 @@ def efg_dl_ef(efg_file: str) -> str: try: efg_path = Path(efg_file) - out_path = efg_path.with_suffix('.ef') - with open(out_path, 'w', encoding='utf-8') as f: - f.write('\n'.join(out_lines) + '\n') + out_path = efg_path.with_suffix(".ef") + with open(out_path, "w", encoding="utf-8") as f: + f.write("\n".join(out_lines) + "\n") return str(out_path) except Exception: - return '\n'.join(out_lines) + return "\n".join(out_lines)