diff --git a/SHELL_FEATURES.md b/SHELL_FEATURES.md index 9236b3c3..144fa05f 100644 --- a/SHELL_FEATURES.md +++ b/SHELL_FEATURES.md @@ -7,7 +7,7 @@ The in-shell `help` command mirrors these feature categories: run `help` for a c ## Builtins -- ✅ `awk [-F SEP] [-v NAME=VALUE] ['PROGRAM'|-f PROGRAM-FILE] [FILE]...` — pattern scanning and text processing; supports BEGIN/main/END rules, fields and field mutation (`$0`, `$1`, `$NF`), `NF`/`NR`/`FNR`/`FILENAME`, `FS`/`OFS`/`ORS`, regex `FS`, `print`, `printf`, scalar and associative array assignment, `split`, `in`, `delete`, `for`, `while`, `break`, `continue`, range patterns, arithmetic/comparison/boolean expressions, regex patterns and `~`/`!~`, string concatenation, `if`/`else`, `next`, `ENVIRON`, and scalar builtins (`length`, `substr`, `index`, `tolower`, `toupper`, `int`); `system()`, command pipes, output redirection, `getline`, user-defined functions, and many POSIX/GNU awk builtins remain rejected or deferred +- ✅ `awk [-F SEP] [-v NAME=VALUE] ['PROGRAM'|-f PROGRAM-FILE] [FILE]...` — pattern scanning and text processing; supports BEGIN/main/END rules, fields and field mutation (`$0`, `$1`, `$NF`), `NF`/`NR`/`FNR`/`FILENAME`, `FS`/`RS`/`OFS`/`ORS`/`SUBSEP`, `RSTART`/`RLENGTH`, regex `FS`, single-character `RS`, `IGNORECASE`, `print`, `printf`, `sprintf`, scalar and associative array assignment, composite array keys, `split`, `sub`, `gsub`, `gensub`, `match` with capture arrays, `strtonum`, `asorti`, `in`, `delete`, `for`, `while`, `break`, `continue`, `exit`, range patterns, arithmetic/comparison/boolean/ternary expressions, regex patterns and `~`/`!~`, string concatenation, `if`/`else`, `next`, `ENVIRON`, user-defined functions with `return` and scalar or array parameters, current/file/command-pipe `getline`, output command pipes through rshell builtins, and scalar builtins (`length`, `substr`, `index`, `tolower`, `toupper`, `int`); `system()`, file output redirection, ARGV/ARGC mutation, BEGINFILE/ENDFILE, `nextfile`, include/load, namespaces, indirect calls, FIELDWIDTHS/FPAT/CSV mode, PROCINFO/SYMTAB/FUNCTAB, extension loading, and many POSIX/GNU awk utility builtins remain rejected or deferred - ✅ `break` — exit the innermost `for` loop - ✅ `cat [-AbeEnstTuv] [FILE]...` — concatenate files to stdout; supports line numbering, blank squeezing, and non-printing character display - ✅ `continue` — skip to the next iteration of the innermost `for` loop diff --git a/analysis/symbols_builtins.go b/analysis/symbols_builtins.go index 20485cfa..024aed33 100644 --- a/analysis/symbols_builtins.go +++ b/analysis/symbols_builtins.go @@ -29,6 +29,9 @@ package analysis var builtinPerCommandSymbols = map[string][]string{ "awk": { "bufio.NewScanner", // 🟢 line-by-line record reading; no write or exec capability. + "bufio.Scanner", // 🟢 scanner type retained for incremental getline state; no write or exec capability. + "bytes.Buffer", // 🟢 in-memory command pipe buffer; no filesystem/network/exec side effects. + "bytes.NewReader", // 🟢 wraps buffered command-pipe bytes as stdin; pure in-memory, no I/O. "context.Context", // 🟢 deadline/cancellation plumbing; pure interface, no side effects. "errors.Is", // 🟢 error comparison; pure function, no I/O. "errors.New", // 🟢 creates a simple error value; pure function, no I/O. @@ -37,6 +40,7 @@ var builtinPerCommandSymbols = map[string][]string{ "io.EOF", // 🟢 sentinel error value; pure constant. "io.NopCloser", // 🟢 wraps a Reader with a no-op Close; no side effects. "io.ReadCloser", // 🟢 interface type; no side effects. + "io.Reader", // 🟢 interface type for command-pipe stdin; no side effects. "math/big.Float", // 🟢 arbitrary-precision float type used to convert large awk printf integers; pure in-memory arithmetic. "math/big.Int", // 🟢 arbitrary-precision integer type used for large awk printf integers; pure in-memory arithmetic. "math/big.NewInt", // 🟢 constructs an in-memory integer value; pure function, no I/O. diff --git a/builtins/awk/ast.go b/builtins/awk/ast.go index e7f7c6ac..5c9ab403 100644 --- a/builtins/awk/ast.go +++ b/builtins/awk/ast.go @@ -6,7 +6,14 @@ package awk type program struct { - rules []rule + rules []rule + functions map[string]*functionDef +} + +type functionDef struct { + name string + params []string + body []stmt } type ruleKind int @@ -29,12 +36,14 @@ type stmt interface { type printStmt struct { args []expr + pipe expr } func (*printStmt) stmtNode() {} type printfStmt struct { args []expr + pipe expr } func (*printfStmt) stmtNode() {} @@ -43,6 +52,7 @@ type ifStmt struct { cond expr thenStmts []stmt elseStmts []stmt + endsBlock bool } func (*ifStmt) stmtNode() {} @@ -51,22 +61,25 @@ type forInStmt struct { varName string arrayName string body []stmt + endsBlock bool } func (*forInStmt) stmtNode() {} type forStmt struct { - init expr - cond expr - post expr - body []stmt + init expr + cond expr + post expr + body []stmt + endsBlock bool } func (*forStmt) stmtNode() {} type whileStmt struct { - cond expr - body []stmt + cond expr + body []stmt + endsBlock bool } func (*whileStmt) stmtNode() {} @@ -75,6 +88,18 @@ type nextStmt struct{} func (*nextStmt) stmtNode() {} +type exitStmt struct { + status expr +} + +func (*exitStmt) stmtNode() {} + +type returnStmt struct { + value expr +} + +func (*returnStmt) stmtNode() {} + type breakStmt struct{} func (*breakStmt) stmtNode() {} @@ -84,9 +109,9 @@ type continueStmt struct{} func (*continueStmt) stmtNode() {} type deleteStmt struct { - name string - index expr - all bool + name string + indices []expr + all bool } func (*deleteStmt) stmtNode() {} @@ -127,12 +152,18 @@ type varExpr struct { func (*varExpr) exprNode() {} type arrayRefExpr struct { - name string - index expr + name string + indices []expr } func (*arrayRefExpr) exprNode() {} +type compositeExpr struct { + parts []expr +} + +func (*compositeExpr) exprNode() {} + type fieldExpr struct { index expr } @@ -160,6 +191,14 @@ type binaryExpr struct { func (*binaryExpr) exprNode() {} +type ternaryExpr struct { + cond expr + then expr + els expr +} + +func (*ternaryExpr) exprNode() {} + type rangeExpr struct { start expr end expr @@ -189,3 +228,19 @@ type callExpr struct { } func (*callExpr) exprNode() {} + +type getlineSourceKind int + +const ( + getlineMain getlineSourceKind = iota + getlineFile + getlineCommand +) + +type getlineExpr struct { + target expr + source expr + kind getlineSourceKind +} + +func (*getlineExpr) exprNode() {} diff --git a/builtins/awk/awk.go b/builtins/awk/awk.go index bd4cc05b..7c33ed4b 100644 --- a/builtins/awk/awk.go +++ b/builtins/awk/awk.go @@ -14,15 +14,22 @@ // This implements a practical, intentionally restricted awk profile: program // loading from an inline argument or -f files, -F field // separators, -v scalar variables, BEGIN/main/END rules, print and printf, -// scalar and associative array assignment, if/else, for/while loops, next, -// arithmetic/comparison/boolean expressions, regex patterns and match -// operators, regex field separators, string concatenation, scalar built-in -// functions, split, delete, ENVIRON, and field/built-in variables such as $0, -// $1, NF, NR, FNR, FILENAME, FS, OFS, and ORS. +// scalar and associative array assignment, composite array keys, if/else, +// for/while loops, next, exit, arithmetic/comparison/boolean/ternary +// expressions, regex patterns and match operators, regex field separators, +// string concatenation, scalar built-in functions, split, sub, gsub, gensub, +// match, sprintf, strtonum, asorti, delete, ENVIRON, IGNORECASE, +// user-defined functions with return and scalar or +// array parameters, and field/built-in variables such as $0, $1, NF, NR, FNR, +// FILENAME, FS, RS, OFS, ORS, SUBSEP, RSTART, and RLENGTH. // -// Blocked or deferred features include system(), command pipes, output -// redirection, getline, user-defined functions, and many additional POSIX/GNU -// awk builtins. +// Command strings in awk pipes are parsed and executed by rshell under the +// active sandbox. Blocked or deferred features include system(), awk file +// output redirection, +// ARGV/ARGC, BEGINFILE/ENDFILE, +// nextfile, include/load, namespaces, FIELDWIDTHS/FPAT/CSV mode, introspection +// variables such as PROCINFO/SYMTAB/FUNCTAB, indirect calls, and many +// additional POSIX/GNU awk builtins. package awk import ( @@ -139,7 +146,26 @@ func registerFlags(fs *builtins.FlagSet) builtins.HandlerFunc { func printHelp(callCtx *builtins.CallContext, fs *builtins.FlagSet) { callCtx.Out("Usage: awk [OPTION]... 'program' [FILE]...\n") callCtx.Out("Pattern scanning and text processing.\n") + callCtx.Out("This is a practical rshell awk profile, not a full GNU awk clone.\n") callCtx.Out("With no FILE, or when FILE is -, read standard input.\n\n") + + callCtx.Out("Supported profile:\n") + callCtx.Out(" - Inline programs, -f program files, -F separators, -v assignments, FILE args, and - for stdin.\n") + callCtx.Out(" - BEGIN/main/END rules; regex, comparison, boolean, and range patterns.\n") + callCtx.Out(" - Fields and records: $0, $1..$NF, NF, NR, FNR, FILENAME, FS, RS, OFS, ORS, SUBSEP, RSTART, RLENGTH.\n") + callCtx.Out(" - Scalars, associative arrays, composite keys, ENVIRON, IGNORECASE, arithmetic, comparisons, regex match, ternary, and string concatenation.\n") + callCtx.Out(" - if/else, for, for-in, while, break, continue, next, exit, and user-defined functions with return.\n") + callCtx.Out(" - print, printf, sprintf, length, substr, index, tolower, toupper, int, split, sub, gsub, gensub, match, strtonum, asorti, delete, and close.\n") + callCtx.Out(" - Output command pipes such as print x | \"sort\" and rshell command strings such as print x | \"cat | sort\".\n") + callCtx.Out(" - getline, getline var, getline var < file, and \"cmd\" | getline var; file reads use rshell path policy and command strings run through rshell.\n\n") + + callCtx.Out("Not supported:\n") + callCtx.Out(" - system(). Use supported awk command pipes/getline pipes instead; command strings run through rshell and its active sandbox.\n") + callCtx.Out(" - print/printf file output redirection to file targets, such as print x > \"file\" or printf ... >> \"file\". Output command pipes remain supported and their command strings follow normal rshell policy.\n") + callCtx.Out(" - ARGV/ARGC mutation, BEGINFILE/ENDFILE, nextfile, do/while, switch, include/load, namespaces, and indirect function calls.\n") + callCtx.Out(" - GNU awk CSV mode, FIELDWIDTHS, FPAT, PROCINFO, SYMTAB, FUNCTAB, typed regexps, and extension loading.\n") + callCtx.Out(" - Many GNU/POSIX utility builtins are intentionally absent, including asort, patsplit, math/time/random helpers, bitwise, typeof, and i18n functions.\n\n") + fs.SetOutput(callCtx.Stdout) fs.PrintDefaults() } diff --git a/builtins/awk/eval.go b/builtins/awk/eval.go index 8f544e65..0a3e999e 100644 --- a/builtins/awk/eval.go +++ b/builtins/awk/eval.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "math" + "strconv" "strings" ) @@ -17,11 +18,38 @@ var errNextRecord = errors.New("next record") var errBreakLoop = errors.New("break loop") var errContinueLoop = errors.New("continue loop") +type exitError struct { + code int +} + +func (e *exitError) Error() string { + return "exit" +} + +type returnError struct { + value value +} + +func (e *returnError) Error() string { + return "return" +} + func (rt *runtime) execStatements(ctx context.Context, stmts []stmt) error { - for _, st := range stmts { + return rt.execStatementsWithFuture(ctx, stmts, nil) +} + +func (rt *runtime) execStatementsWithFuture(ctx context.Context, stmts []stmt, future []stmt) error { + prevCtx := rt.ctx + rt.ctx = ctx + defer func() { rt.ctx = prevCtx }() + prevFuture := rt.futureStmts + defer func() { rt.futureStmts = prevFuture }() + for i, st := range stmts { if err := ctx.Err(); err != nil { return err } + remaining := stmtFuture(stmts[i+1:], future) + rt.futureStmts = remaining switch s := st.(type) { case *printStmt: vals := make([]value, 0, len(s.args)) @@ -36,7 +64,8 @@ func (rt *runtime) execStatements(ctx context.Context, stmts []stmt) error { vals = append(vals, v) } } - if err := rt.printValues(vals); err != nil { + out := rt.formatPrintValues(vals) + if err := rt.writeOutput(ctx, s.pipe, out, remaining); err != nil { return err } case *printfStmt: @@ -55,18 +84,20 @@ func (rt *runtime) execStatements(ctx context.Context, stmts []stmt) error { if err != nil { return err } - rt.callCtx.Out(out) + if err := rt.writeOutput(ctx, s.pipe, out, remaining); err != nil { + return err + } case *ifStmt: cond, err := rt.eval(s.cond) if err != nil { return err } if cond.Bool() { - if err := rt.execStatements(ctx, s.thenStmts); err != nil { + if err := rt.execStatementsWithFuture(ctx, s.thenStmts, remaining); err != nil { return err } } else if len(s.elseStmts) > 0 { - if err := rt.execStatements(ctx, s.elseStmts); err != nil { + if err := rt.execStatementsWithFuture(ctx, s.elseStmts, remaining); err != nil { return err } } @@ -79,7 +110,7 @@ func (rt *runtime) execStatements(ctx context.Context, stmts []stmt) error { if err := rt.setVar(s.varName, stringValue(key)); err != nil { return err } - if err := rt.execStatements(ctx, s.body); err != nil { + if err := rt.execStatementsWithFuture(ctx, s.body, stmtFuture(s.body, remaining)); err != nil { if errors.Is(err, errBreakLoop) { break } @@ -90,15 +121,35 @@ func (rt *runtime) execStatements(ctx context.Context, stmts []stmt) error { } } case *forStmt: - if err := rt.execFor(ctx, s); err != nil { + if err := rt.execFor(ctx, s, remaining); err != nil { return err } case *whileStmt: - if err := rt.execWhile(ctx, s); err != nil { + if err := rt.execWhile(ctx, s, remaining); err != nil { return err } case *nextStmt: return errNextRecord + case *exitStmt: + code := rt.exitCode + if s.status != nil { + status, err := rt.eval(s.status) + if err != nil { + return err + } + code = int(status.Number()) + } + rt.exitCode = code + return &exitError{code: code} + case *returnStmt: + if s.value == nil { + return &returnError{value: unassignedValue()} + } + v, err := rt.eval(s.value) + if err != nil { + return err + } + return &returnError{value: v} case *breakStmt: return errBreakLoop case *continueStmt: @@ -110,11 +161,11 @@ func (rt *runtime) execStatements(ctx context.Context, stmts []stmt) error { } continue } - key, err := rt.eval(s.index) + key, err := rt.evalArrayKey(s.indices) if err != nil { return err } - if err := rt.deleteArrayElem(s.name, key.String()); err != nil { + if err := rt.deleteArrayElem(s.name, key); err != nil { return err } case *exprStmt: @@ -128,7 +179,20 @@ func (rt *runtime) execStatements(ctx context.Context, stmts []stmt) error { return nil } -func (rt *runtime) execFor(ctx context.Context, s *forStmt) error { +func stmtFuture(remaining, future []stmt) []stmt { + if len(remaining) == 0 { + return future + } + if len(future) == 0 { + return remaining + } + out := make([]stmt, 0, len(remaining)+len(future)) + out = append(out, remaining...) + out = append(out, future...) + return out +} + +func (rt *runtime) execFor(ctx context.Context, s *forStmt, future []stmt) error { if s.init != nil { if _, err := rt.eval(s.init); err != nil { return err @@ -147,7 +211,7 @@ func (rt *runtime) execFor(ctx context.Context, s *forStmt) error { return nil } } - err := rt.execStatements(ctx, s.body) + err := rt.execStatementsWithFuture(ctx, s.body, stmtFuture(s.body, future)) if errors.Is(err, errBreakLoop) { return nil } @@ -162,7 +226,7 @@ func (rt *runtime) execFor(ctx context.Context, s *forStmt) error { } } -func (rt *runtime) execWhile(ctx context.Context, s *whileStmt) error { +func (rt *runtime) execWhile(ctx context.Context, s *whileStmt, future []stmt) error { for { if err := ctx.Err(); err != nil { return err @@ -174,7 +238,7 @@ func (rt *runtime) execWhile(ctx context.Context, s *whileStmt) error { if !cond.Bool() { return nil } - err = rt.execStatements(ctx, s.body) + err = rt.execStatementsWithFuture(ctx, s.body, stmtFuture(s.body, future)) if errors.Is(err, errBreakLoop) { return nil } @@ -208,13 +272,22 @@ func substrEnd(start, length int, count float64) int { } func (rt *runtime) printValues(vals []value) error { + return rt.writeStdoutString(rt.ctx, rt.formatPrintValues(vals), nil) +} + +func (rt *runtime) formatPrintValues(vals []value) string { parts := make([]string, len(vals)) for i, v := range vals { parts[i] = v.String() } - rt.callCtx.Out(strings.Join(parts, rt.getVar("OFS").String())) - rt.callCtx.Out(rt.getVar("ORS").String()) - return nil + return strings.Join(parts, rt.getVar("OFS").String()) + rt.getVar("ORS").String() +} + +func (rt *runtime) writeOutput(ctx context.Context, pipe expr, out string, remaining []stmt) error { + if pipe == nil { + return rt.writeStdoutString(ctx, out, remaining) + } + return rt.writeCommandPipe(ctx, pipe, out) } func (rt *runtime) eval(x expr) (value, error) { @@ -224,7 +297,7 @@ func (rt *runtime) eval(x expr) (value, error) { case *stringExpr: return stringValue(e.value), nil case *regexExpr: - re, err := compileRegex(e.pattern) + re, err := rt.compileRegex(e.pattern) if err != nil { return value{}, err } @@ -236,6 +309,12 @@ func (rt *runtime) eval(x expr) (value, error) { return rt.getVar(e.name), nil case *arrayRefExpr: return rt.evalArrayRef(e) + case *compositeExpr: + key, err := rt.evalArrayKey(e.parts) + if err != nil { + return value{}, err + } + return stringValue(key), nil case *fieldExpr: v, err := rt.eval(e.index) if err != nil { @@ -268,10 +347,21 @@ func (rt *runtime) eval(x expr) (value, error) { } case *binaryExpr: return rt.evalBinary(e) + case *ternaryExpr: + cond, err := rt.eval(e.cond) + if err != nil { + return value{}, err + } + if cond.Bool() { + return rt.eval(e.then) + } + return rt.eval(e.els) case *assignExpr: return rt.evalAssign(e) case *incDecExpr: return rt.evalIncDec(e) + case *getlineExpr: + return rt.evalGetline(e) case *callExpr: return rt.evalCall(e) default: @@ -280,9 +370,30 @@ func (rt *runtime) eval(x expr) (value, error) { } func (rt *runtime) evalCall(e *callExpr) (value, error) { + if fn, ok := rt.prog.functions[e.name]; ok { + return rt.evalUserFunction(fn, e.args) + } if e.name == "split" { return rt.evalSplit(e) } + if e.name == "sub" || e.name == "gsub" { + return rt.evalSubstitution(e) + } + if e.name == "match" { + return rt.evalMatch(e) + } + if e.name == "gensub" { + return rt.evalGensub(e) + } + if e.name == "length" { + return rt.evalLength(e) + } + if e.name == "close" { + return rt.evalClose(e) + } + if e.name == "asorti" { + return rt.evalAsorti(e) + } args := make([]value, 0, len(e.args)) for _, arg := range e.args { v, err := rt.eval(arg) @@ -295,12 +406,6 @@ func (rt *runtime) evalCall(e *callExpr) (value, error) { return value{}, err } switch e.name { - case "length": - s := rt.field(0).String() - if len(args) == 1 { - s = args[0].String() - } - return numberValue(float64(len([]rune(s)))), nil case "substr": s := []rune(args[0].String()) start := substrStart(args[1].Number(), len(s)) @@ -332,9 +437,674 @@ func (rt *runtime) evalCall(e *callExpr) (value, error) { case "int": v := args[0] return numberValue(math.Trunc(v.Number())), nil + case "strtonum": + return numberValue(parseAwkNumberLiteral(args[0].String())), nil + case "sprintf": + out, err := formatPrintf(args[0].String(), args[1:]) + if err != nil { + return value{}, err + } + return stringValue(out), nil default: - return value{}, fmt.Errorf("function calls are not supported") + if _, ok := unsupportedBuiltinFunctions[e.name]; ok { + return value{}, fmt.Errorf("function calls are not supported") + } + return value{}, fmt.Errorf("function %q not defined", e.name) + } +} + +func (rt *runtime) evalClose(e *callExpr) (value, error) { + if err := validateBuiltinCallArity(e.name, len(e.args)); err != nil { + return value{}, err + } + command, err := rt.eval(e.args[0]) + if err != nil { + return value{}, err + } + status, ok, err := rt.closeCommandPipe(rt.ctx, command.String(), true) + if err != nil { + return value{}, err } + if ok { + return numberValue(float64(status)), nil + } + status, ok, err = rt.closeCommandInput(command.String()) + if err != nil { + return value{}, err + } + if ok { + return numberValue(float64(status)), nil + } + if status, ok := rt.closeInputFile(command.String()); ok { + return numberValue(float64(status)), nil + } + rt.setErrnoString("close of redirection that was never opened") + return numberValue(-1), nil +} + +func (rt *runtime) evalGetline(e *getlineExpr) (value, error) { + var target assignTarget + hasTarget := e.target != nil + if hasTarget { + resolved, _, err := rt.resolveAssignable(e.target) + if err != nil { + return value{}, err + } + target = resolved + } + + rec, status, err := rt.readGetlineRecord(e) + if err != nil { + return value{}, err + } + if status != 1 { + return numberValue(float64(status)), nil + } + if hasTarget { + if err := rt.setResolvedAssignable(target, inputStringValue(rec)); err != nil { + return value{}, err + } + return numberValue(1), nil + } + if err := rt.setRecord(rec); err != nil { + return value{}, err + } + return numberValue(1), nil +} + +func (rt *runtime) readGetlineRecord(e *getlineExpr) (string, int, error) { + switch e.kind { + case getlineMain: + rec, ok, err := rt.readMainRecord(rt.ctx) + if err != nil { + return "", 0, err + } + if !ok { + return "", 0, nil + } + return rec, 1, nil + case getlineFile: + source, err := rt.eval(e.source) + if err != nil { + return "", 0, err + } + return rt.getlineFileRecord(rt.ctx, source.String()) + case getlineCommand: + source, err := rt.eval(e.source) + if err != nil { + return "", 0, err + } + return rt.getlineCommandRecord(rt.ctx, source.String()) + default: + return "", 0, fmt.Errorf("unknown getline source") + } +} + +func (rt *runtime) evalLength(e *callExpr) (value, error) { + if err := validateBuiltinCallArity(e.name, len(e.args)); err != nil { + return value{}, err + } + if len(e.args) == 0 { + return numberValue(float64(len([]rune(rt.field(0).String())))), nil + } + if arg, ok := e.args[0].(*varExpr); ok { + rt.ensureBuiltinArray(arg.name) + if rt.isArray(arg.name) { + keys, err := rt.arrayKeys(arg.name) + if err != nil { + return value{}, err + } + return numberValue(float64(len(keys))), nil + } + } + v, err := rt.eval(e.args[0]) + if err != nil { + return value{}, err + } + return numberValue(float64(len([]rune(v.String())))), nil +} + +type functionArg struct { + value value + valueSet bool + arrayAlias *localVar + globalArrayName string +} + +func (rt *runtime) evalUserFunction(fn *functionDef, args []expr) (value, error) { + if len(args) > len(fn.params) { + return value{}, fmt.Errorf("function %q called with too many arguments", fn.name) + } + callArgs := make([]functionArg, len(args)) + for i, arg := range args { + v, err := rt.evalFunctionArg(arg) + if err != nil { + return value{}, err + } + callArgs[i] = v + } + frame := callFrame{locals: make(map[string]*localVar, len(fn.params))} + for _, param := range fn.params { + frame.locals[param] = &localVar{} + } + globalAliases := make(map[string]*localVar) + rt.frames = append(rt.frames, frame) + defer rt.popFrame() + for i, arg := range callArgs { + local := rt.lookupLocal(fn.params[i]) + local.arrayAlias = arg.arrayAlias + if arg.globalArrayName != "" { + alias := globalAliases[arg.globalArrayName] + if alias == nil { + alias = &localVar{globalArrayName: arg.globalArrayName} + globalAliases[arg.globalArrayName] = alias + } + local.arrayAlias = alias + } + if arg.valueSet { + if err := rt.setLocalScalar(local, arg.value); err != nil { + return value{}, err + } + } + } + if rt.ctx == nil { + return value{}, fmt.Errorf("missing evaluation context") + } + err := rt.execStatementsWithFuture(rt.ctx, fn.body, rt.futureStmts) + if ret, ok := err.(*returnError); ok { + return ret.value, nil + } + if err != nil { + return value{}, err + } + return unassignedValue(), nil +} + +func (rt *runtime) evalFunctionArg(arg expr) (functionArg, error) { + if v, ok := arg.(*varExpr); ok { + return rt.evalVariableFunctionArg(v.name) + } + value, err := rt.eval(arg) + if err != nil { + return functionArg{}, err + } + return functionArg{value: value, valueSet: true}, nil +} + +func (rt *runtime) evalVariableFunctionArg(name string) (functionArg, error) { + if local := rt.lookupLocal(name); local != nil { + arg := functionArg{} + if local.valueSet { + arg.value = local.value + arg.valueSet = true + } + root := rootLocalVar(local) + if rt.localIsArray(root) || !local.valueSet { + arg.arrayAlias = root + } + return arg, nil + } + if rt.isGlobalArray(name) { + return functionArg{globalArrayName: name}, nil + } + if v, ok := rt.vars[name]; ok { + return functionArg{value: v, valueSet: true}, nil + } + if isBuiltinArrayName(name) { + return functionArg{globalArrayName: name}, nil + } + if isBuiltinScalarName(name) { + return functionArg{value: rt.getVar(name), valueSet: true}, nil + } + return functionArg{globalArrayName: name}, nil +} + +func (rt *runtime) popFrame() { + if len(rt.frames) == 0 { + return + } + frame := rt.frames[len(rt.frames)-1] + rt.frames = rt.frames[:len(rt.frames)-1] + for _, local := range frame.locals { + rt.varBytes -= local.valueSize + if local.arrayAlias != nil || local.globalArrayName != "" { + continue + } + for _, size := range local.arraySizes { + rt.varBytes -= size + } + } + if rt.varBytes < 0 { + rt.varBytes = 0 + } +} + +func (rt *runtime) evalSubstitution(e *callExpr) (value, error) { + if err := validateBuiltinCallArity(e.name, len(e.args)); err != nil { + return value{}, err + } + re, err := rt.compileRegexArg(e.args[0]) + if err != nil { + return value{}, err + } + repl, err := rt.eval(e.args[1]) + if err != nil { + return value{}, err + } + var target assignTarget + var current value + if len(e.args) == 3 { + target, current, err = rt.resolveAssignable(e.args[2]) + if err != nil { + return value{}, err + } + } else { + target = assignTarget{field: true, fieldIndex: 0} + current = rt.field(0) + } + next, count, err := substituteAwk(re, current.String(), repl.String(), e.name == "gsub") + if err != nil { + return value{}, err + } + if count == 0 { + return numberValue(0), nil + } + if err := rt.setResolvedAssignable(target, stringValue(next)); err != nil { + return value{}, err + } + return numberValue(float64(count)), nil +} + +func (rt *runtime) evalMatch(e *callExpr) (value, error) { + if err := validateBuiltinCallArity(e.name, len(e.args)); err != nil { + return value{}, err + } + var captures *varExpr + if len(e.args) == 3 { + var ok bool + captures, ok = e.args[2].(*varExpr) + if !ok { + return value{}, fmt.Errorf("match capture destination must be an array variable") + } + } + input, err := rt.eval(e.args[0]) + if err != nil { + return value{}, err + } + re, err := rt.compileRegexArg(e.args[1]) + if err != nil { + return value{}, err + } + if captures != nil { + if err := rt.deleteArray(captures.name); err != nil { + return value{}, err + } + } + text := input.String() + match := re.FindStringRuneIndex(text) + if match == nil { + if err := rt.setVar("RSTART", numberValue(0)); err != nil { + return value{}, err + } + if err := rt.setVar("RLENGTH", numberValue(-1)); err != nil { + return value{}, err + } + return numberValue(0), nil + } + start := match[0] + 1 + length := match[1] - match[0] + if err := rt.setVar("RSTART", numberValue(float64(start))); err != nil { + return value{}, err + } + if err := rt.setVar("RLENGTH", numberValue(float64(length))); err != nil { + return value{}, err + } + if captures != nil { + if err := rt.setMatchCaptures(captures.name, text, re); err != nil { + return value{}, err + } + } + return numberValue(float64(start)), nil +} + +func (rt *runtime) setMatchCaptures(name, text string, re *awkRegex) error { + locs := re.FindStringSubmatchIndex(text) + for i := 0; i+1 < len(locs); i += 2 { + key := fmt.Sprintf("%d", i/2) + value := "" + if locs[i] >= 0 { + value = text[locs[i]:locs[i+1]] + } + if err := rt.setArrayElem(name, key, inputStringValue(value)); err != nil { + return err + } + } + return nil +} + +func (rt *runtime) evalGensub(e *callExpr) (value, error) { + if err := validateBuiltinCallArity(e.name, len(e.args)); err != nil { + return value{}, err + } + re, err := rt.compileRegexArg(e.args[0]) + if err != nil { + return value{}, err + } + repl, err := rt.eval(e.args[1]) + if err != nil { + return value{}, err + } + how, err := rt.eval(e.args[2]) + if err != nil { + return value{}, err + } + target := rt.field(0) + if len(e.args) == 4 { + target, err = rt.eval(e.args[3]) + if err != nil { + return value{}, err + } + } + out, err := gensubAwk(re, target.String(), repl.String(), how) + if err != nil { + return value{}, err + } + return stringValue(out), nil +} + +func (rt *runtime) evalAsorti(e *callExpr) (value, error) { + if err := validateBuiltinCallArity(e.name, len(e.args)); err != nil { + return value{}, err + } + source, ok := e.args[0].(*varExpr) + if !ok { + return value{}, fmt.Errorf("asorti source must be an array variable") + } + destName := source.name + if len(e.args) == 2 { + dest, ok := e.args[1].(*varExpr) + if !ok { + return value{}, fmt.Errorf("asorti destination must be an array variable") + } + destName = dest.name + } + keys, err := rt.arrayKeysSorted(source.name, rt.ignoreCase()) + if err != nil { + return value{}, err + } + elems := make(map[string]value, len(keys)) + for i, key := range keys { + elems[fmt.Sprintf("%d", i+1)] = inputStringValue(key) + } + if err := rt.replaceArray(destName, elems); err != nil { + return value{}, err + } + return numberValue(float64(len(keys))), nil +} + +func (rt *runtime) compileRegexArg(x expr) (*awkRegex, error) { + if rx, ok := x.(*regexExpr); ok { + return rt.compileRegex(rx.pattern) + } + v, err := rt.eval(x) + if err != nil { + return nil, err + } + return rt.compileRegex(v.String()) +} + +func substituteAwk(re *awkRegex, input, replacement string, all bool) (string, int, error) { + var matches [][]int + if all { + matches = re.FindAllStringIndex(input, -1) + } else if loc := re.FindStringIndex(input); loc != nil { + matches = [][]int{loc} + } + if len(matches) == 0 { + return input, 0, nil + } + + var b strings.Builder + last := 0 + for _, loc := range matches { + start := loc[0] + end := loc[1] + if err := appendLimitedString(&b, input[last:start]); err != nil { + return "", 0, err + } + if err := appendAwkReplacement(&b, replacement, input[start:end]); err != nil { + return "", 0, err + } + last = end + } + if err := appendLimitedString(&b, input[last:]); err != nil { + return "", 0, err + } + return b.String(), len(matches), nil +} + +func gensubAwk(re *awkRegex, input, replacement string, how value) (string, error) { + locs := re.FindAllStringSubmatchIndex(input, -1) + if len(locs) == 0 { + return input, nil + } + global := false + nth := int(how.Number()) + howString := how.String() + if hasLeadingG(howString) { + global = true + nth = 1 + } + if nth < 1 { + nth = 1 + } + + var b strings.Builder + last := 0 + seen := 0 + for _, loc := range locs { + if loc[0] == loc[1] && loc[0] == last && seen > 0 { + continue + } + seen++ + replace := global || seen == nth + if !replace { + continue + } + if err := appendLimitedString(&b, input[last:loc[0]]); err != nil { + return "", err + } + if err := appendGensubReplacement(&b, replacement, input, loc); err != nil { + return "", err + } + last = loc[1] + if !global { + break + } + } + if last == 0 && !(global || seen >= nth) { + return input, nil + } + if err := appendLimitedString(&b, input[last:]); err != nil { + return "", err + } + return b.String(), nil +} + +func hasLeadingG(s string) bool { + return len(s) > 0 && (s[0] == 'g' || s[0] == 'G') +} + +func appendGensubReplacement(b *strings.Builder, replacement, input string, loc []int) error { + for i := 0; i < len(replacement); i++ { + switch replacement[i] { + case '&': + if err := appendSubmatch(b, input, loc, 0); err != nil { + return err + } + case '\\': + if i+1 >= len(replacement) { + if err := appendLimitedString(b, `\`); err != nil { + return err + } + continue + } + next := replacement[i+1] + i++ + if next >= '0' && next <= '9' { + if err := appendSubmatch(b, input, loc, int(next-'0')); err != nil { + return err + } + continue + } + if err := appendLimitedString(b, string(next)); err != nil { + return err + } + default: + if err := appendLimitedString(b, replacement[i:i+1]); err != nil { + return err + } + } + } + return nil +} + +func appendSubmatch(b *strings.Builder, input string, loc []int, group int) error { + i := group * 2 + if i+1 >= len(loc) || loc[i] < 0 { + return nil + } + return appendLimitedString(b, input[loc[i]:loc[i+1]]) +} + +func appendAwkReplacement(b *strings.Builder, replacement, matched string) error { + for i := 0; i < len(replacement); i++ { + switch replacement[i] { + case '&': + if err := appendLimitedString(b, matched); err != nil { + return err + } + case '\\': + if i+1 >= len(replacement) { + if err := appendLimitedString(b, `\`); err != nil { + return err + } + continue + } + next := replacement[i+1] + i++ + if next == '&' || next == '\\' { + if err := appendLimitedString(b, string(next)); err != nil { + return err + } + continue + } + if err := appendLimitedString(b, `\`+string(next)); err != nil { + return err + } + default: + if err := appendLimitedString(b, replacement[i:i+1]); err != nil { + return err + } + } + } + return nil +} + +func parseAwkNumberLiteral(s string) float64 { + text := strings.TrimSpace(s) + if text == "" { + return 0 + } + sign := 1.0 + if text[0] == '+' || text[0] == '-' { + if text[0] == '-' { + sign = -1 + } + text = text[1:] + } + if len(text) > 2 && text[0] == '0' && (text[1] == 'x' || text[1] == 'X') { + if n, ok := parseUnsignedBasePrefix(text[2:], 16); ok { + return sign * float64(n) + } + return 0 + } + if shouldParseAwkOctalPrefix(text) { + if n, ok := parseUnsignedBasePrefix(text[1:], 8); ok { + return sign * float64(n) + } + return 0 + } + prefix := numericPrefix(text) + if prefix == "" { + return 0 + } + if n, err := strconv.ParseFloat(prefix, 64); err == nil { + return sign * n + } + return 0 +} + +func shouldParseAwkOctalPrefix(s string) bool { + if len(s) <= 1 || s[0] != '0' || s[1] < '0' || s[1] > '7' { + return false + } + for i := 1; i < len(s); i++ { + ch := s[i] + switch { + case ch >= '0' && ch <= '7': + continue + case ch == '.' || ch == 'e' || ch == 'E' || ch == '8' || ch == '9': + return false + default: + return true + } + } + return true +} + +func parseUnsignedBasePrefix(s string, base int) (uint64, bool) { + if s == "" { + return 0, false + } + var n uint64 + for i := 0; i < len(s); i++ { + digit, ok := digitValue(s[i]) + if !ok || digit >= base { + return n, i > 0 + } + n = n*uint64(base) + uint64(digit) + } + return n, true +} + +func digitValue(ch byte) (int, bool) { + switch { + case ch >= '0' && ch <= '9': + return int(ch - '0'), true + case ch >= 'a' && ch <= 'f': + return int(ch-'a') + 10, true + case ch >= 'A' && ch <= 'F': + return int(ch-'A') + 10, true + default: + return 0, false + } +} + +func appendLimitedString(b *strings.Builder, s string) error { + if len(s) > MaxVariableBytes-b.Len() { + return fmt.Errorf("replacement output exceeds %d bytes", MaxVariableBytes) + } + b.WriteString(s) + return nil +} + +func runeLen(s string) int { + n := 0 + for range s { + n++ + } + return n } func (rt *runtime) evalSplit(e *callExpr) (value, error) { @@ -370,15 +1140,15 @@ func (rt *runtime) evalSplit(e *callExpr) (value, error) { parts = splitAwkChars(input.String()) } else if regexSplit || sep != " " { if regexSplit { - parts, err = splitAwkRegex(input.String(), sep) + parts, err = rt.splitAwkRegex(input.String(), sep) } else { - parts, err = splitAwkFields(input.String(), sep) + parts, err = rt.splitAwkFields(input.String(), sep) } if err != nil { return value{}, err } } else { - parts, err = splitAwkFields(input.String(), sep) + parts, err = rt.splitAwkFields(input.String(), sep) if err != nil { return value{}, err } @@ -485,7 +1255,7 @@ func (rt *runtime) evalBinary(e *binaryExpr) (value, error) { func (rt *runtime) matchRegexExpr(left value, rightExpr expr) (bool, error) { if rx, ok := rightExpr.(*regexExpr); ok { - re, err := compileRegex(rx.pattern) + re, err := rt.compileRegex(rx.pattern) if err != nil { return false, err } @@ -495,7 +1265,7 @@ func (rt *runtime) matchRegexExpr(left value, rightExpr expr) (bool, error) { if err != nil { return false, err } - re, err := compileRegex(right.String()) + re, err := rt.compileRegex(right.String()) if err != nil { return false, err } @@ -580,11 +1350,10 @@ func (rt *runtime) resolveAssignable(x expr) (assignTarget, value, error) { } return assignTarget{name: v.name}, rt.getVar(v.name), nil case *arrayRefExpr: - key, err := rt.eval(v.index) + keyString, err := rt.evalArrayKey(v.indices) if err != nil { return assignTarget{}, value{}, err } - keyString := key.String() current, err := rt.getArrayElem(v.name, keyString) if err != nil { return assignTarget{}, value{}, err @@ -626,11 +1395,29 @@ func (rt *runtime) currentResolvedAssignable(target assignTarget) (value, error) } func (rt *runtime) evalArrayRef(ref *arrayRefExpr) (value, error) { - key, err := rt.eval(ref.index) + key, err := rt.evalArrayKey(ref.indices) if err != nil { return value{}, err } - return rt.getArrayElem(ref.name, key.String()) + return rt.getArrayElem(ref.name, key) +} + +func (rt *runtime) evalArrayKey(indices []expr) (string, error) { + if len(indices) == 0 { + return "", fmt.Errorf("array index is required") + } + parts := make([]string, len(indices)) + for i, index := range indices { + v, err := rt.eval(index) + if err != nil { + return "", err + } + parts[i] = v.String() + } + if len(parts) == 1 { + return parts[0], nil + } + return strings.Join(parts, rt.getVar("SUBSEP").String()), nil } func boolValue(ok bool) value { diff --git a/builtins/awk/lexer.go b/builtins/awk/lexer.go index 614d6548..29ecee93 100644 --- a/builtins/awk/lexer.go +++ b/builtins/awk/lexer.go @@ -28,6 +28,8 @@ const ( tokRBracket tokSemicolon tokComma + tokQuestion + tokColon tokDollar tokAssign tokPlus @@ -130,6 +132,10 @@ func (l *lexer) next() (token, error) { return token{kind: tokSemicolon, lit: ";", pos: start}, nil case ',': return token{kind: tokComma, lit: ",", pos: start}, nil + case '?': + return token{kind: tokQuestion, lit: "?", pos: start}, nil + case ':': + return token{kind: tokColon, lit: ":", pos: start}, nil case '$': return token{kind: tokDollar, lit: "$", pos: start}, nil case '~': @@ -156,12 +162,12 @@ func (l *lexer) next() (token, error) { } return token{kind: tokStar, lit: "*", pos: start}, nil case '/': - if l.match('=') { - return token{kind: tokSlashAssign, lit: "/=", pos: start}, nil - } if canStartRegex(l.last, l.lastLit) { return l.scanRegex(start) } + if l.match('=') { + return token{kind: tokSlashAssign, lit: "/=", pos: start}, nil + } return token{kind: tokSlash, lit: "/", pos: start}, nil case '%': if l.match('=') { @@ -279,6 +285,15 @@ func (l *lexer) scanString(start int) (token, error) { if l.pos >= len(l.src) { return token{}, fmt.Errorf("unterminated string escape") } + if isOctalDigit(l.src[l.pos]) { + value := 0 + for digits := 0; digits < 3 && l.pos < len(l.src) && isOctalDigit(l.src[l.pos]); digits++ { + value = value*8 + int(l.src[l.pos]-'0') + l.pos++ + } + b.WriteByte(byte(value)) + continue + } esc := l.src[l.pos] l.pos++ b.WriteRune(decodeSimpleEscape(esc)) @@ -322,12 +337,15 @@ func (l *lexer) scanRegex(start int) (token, error) { } func canStartRegex(prev tokenKind, prevLit string) bool { - if prev == tokIdent && (prevLit == "print" || prevLit == "printf") { - return true + if prev == tokIdent { + switch prevLit { + case "print", "printf", "return", "exit": + return true + } } switch prev { case tokEOF, tokNewline, tokLBrace, tokRBrace, tokLParen, tokComma, tokSemicolon, - tokAssign, tokPlus, tokMinus, tokStar, tokSlash, tokPercent, tokBang, + tokQuestion, tokColon, tokAssign, tokPlus, tokMinus, tokStar, tokSlash, tokPercent, tokBang, tokLT, tokGT, tokLE, tokGE, tokEQ, tokNE, tokAnd, tokOr, tokMatch, tokNotMatch, tokPlusAssign, tokMinusAssign, tokStarAssign, tokSlashAssign, tokPercentAssign: @@ -349,6 +367,15 @@ func DecodeAwkEscapes(s string) string { b.WriteRune(r) continue } + if isOctalDigit(rune(s[0])) { + value := 0 + for digits := 0; digits < 3 && len(s) > 0 && isOctalDigit(rune(s[0])); digits++ { + value = value*8 + int(s[0]-'0') + s = s[1:] + } + b.WriteByte(byte(value)) + continue + } esc, escSize := utf8.DecodeRuneInString(s) s = s[escSize:] b.WriteRune(decodeSimpleEscape(esc)) @@ -356,6 +383,10 @@ func DecodeAwkEscapes(s string) string { return b.String() } +func isOctalDigit(ch rune) bool { + return ch >= '0' && ch <= '7' +} + func decodeSimpleEscape(esc rune) rune { switch esc { case 'n': diff --git a/builtins/awk/parser.go b/builtins/awk/parser.go index 34ce6050..48eeab9f 100644 --- a/builtins/awk/parser.go +++ b/builtins/awk/parser.go @@ -12,6 +12,7 @@ import ( const ( precAssign = 10 + precTernary = 15 precOr = 20 precAnd = 30 precCompare = 40 @@ -25,34 +26,26 @@ const ( var unsupportedBuiltinFunctions = map[string]struct{}{ "and": {}, "asort": {}, - "asorti": {}, "atan2": {}, "bindtextdomain": {}, - "close": {}, "compl": {}, "cos": {}, "dcgettext": {}, "dcngettext": {}, "exp": {}, "fflush": {}, - "gensub": {}, - "gsub": {}, "isarray": {}, "log": {}, "lshift": {}, - "match": {}, "mktime": {}, "or": {}, "patsplit": {}, "rand": {}, "rshift": {}, "sin": {}, - "sprintf": {}, "sqrt": {}, "srand": {}, "strftime": {}, - "strtonum": {}, - "sub": {}, "system": {}, "systime": {}, "typeof": {}, @@ -60,13 +53,21 @@ var unsupportedBuiltinFunctions = map[string]struct{}{ } var supportedBuiltinFunctions = map[string]struct{}{ - "index": {}, - "int": {}, - "length": {}, - "split": {}, - "substr": {}, - "tolower": {}, - "toupper": {}, + "close": {}, + "asorti": {}, + "gensub": {}, + "gsub": {}, + "index": {}, + "int": {}, + "length": {}, + "match": {}, + "split": {}, + "sprintf": {}, + "strtonum": {}, + "sub": {}, + "substr": {}, + "tolower": {}, + "toupper": {}, } type parser struct { @@ -81,9 +82,21 @@ func parseProgram(src string) (*program, error) { return nil, err } p := &parser{toks: toks} - prog := &program{} + prog := &program{functions: make(map[string]*functionDef)} p.skipSeparators() for !p.at(tokEOF) { + if p.atIdent("function") { + fn, err := p.parseFunctionDefinition() + if err != nil { + return nil, err + } + if _, exists := prog.functions[fn.name]; exists { + return nil, fmt.Errorf("function %q is already defined", fn.name) + } + prog.functions[fn.name] = fn + p.skipSeparators() + continue + } r, err := p.parseRule() if err != nil { return nil, err @@ -91,9 +104,64 @@ func parseProgram(src string) (*program, error) { prog.rules = append(prog.rules, r) p.skipSeparators() } + if err := validateLoopControlStatements(prog); err != nil { + return nil, err + } + if err := validateUserFunctionNameReferences(prog); err != nil { + return nil, err + } return prog, nil } +func (p *parser) parseFunctionDefinition() (*functionDef, error) { + p.advance() + if p.cur().kind != tokIdent { + return nil, fmt.Errorf("expected function name") + } + name := p.cur().lit + if err := validateFunctionName(name); err != nil { + return nil, err + } + p.advance() + if !p.match(tokLParen) { + return nil, fmt.Errorf("expected ( after function name") + } + params := []string{} + seen := make(map[string]int) + p.skipSeparators() + if !p.match(tokRParen) { + for { + p.skipSeparators() + if p.cur().kind != tokIdent { + return nil, fmt.Errorf("expected function parameter") + } + param := p.cur().lit + if err := validateFunctionParameterName(name, param); err != nil { + return nil, err + } + if first, ok := seen[param]; ok { + return nil, fmt.Errorf("function %q parameter #%d, %q, duplicates parameter #%d", name, len(params)+1, param, first) + } + seen[param] = len(params) + 1 + params = append(params, param) + p.advance() + p.skipSeparators() + if p.match(tokRParen) { + break + } + if !p.match(tokComma) { + return nil, fmt.Errorf("expected , or ) in function parameter list") + } + } + } + p.skipSeparators() + body, err := p.parseAction() + if err != nil { + return nil, err + } + return &functionDef{name: name, params: params, body: body}, nil +} + func (p *parser) parseRule() (rule, error) { if p.atIdent("BEGIN") { p.advance() @@ -147,7 +215,7 @@ func (p *parser) parseStatementList() ([]stmt, error) { return nil, err } stmts = append(stmts, st) - if !p.at(tokRBrace) && !p.at(tokEOF) && !isSeparator(p.cur().kind) { + if !p.at(tokRBrace) && !p.at(tokEOF) && !isSeparator(p.cur().kind) && !statementEndsBlock(st) { return nil, fmt.Errorf("expected statement separator") } p.skipSeparators() @@ -156,6 +224,21 @@ func (p *parser) parseStatementList() ([]stmt, error) { return stmts, nil } +func statementEndsBlock(st stmt) bool { + switch s := st.(type) { + case *ifStmt: + return s.endsBlock + case *forStmt: + return s.endsBlock + case *forInStmt: + return s.endsBlock + case *whileStmt: + return s.endsBlock + default: + return false + } +} + func (p *parser) parseStatement() (stmt, error) { if p.atIdent("if") { return p.parseIf() @@ -170,6 +253,12 @@ func (p *parser) parseStatement() (stmt, error) { p.advance() return &nextStmt{}, nil } + if p.atIdent("exit") { + return p.parseExit() + } + if p.atIdent("return") { + return p.parseReturn() + } if p.atIdent("break") { p.advance() return &breakStmt{}, nil @@ -184,15 +273,12 @@ func (p *parser) parseStatement() (stmt, error) { if p.atIdent("printf") { return p.parsePrintf() } - if p.atIdent("if") || p.atIdent("nextfile") || p.atIdent("exit") { + if p.atIdent("if") || p.atIdent("nextfile") { return nil, fmt.Errorf("control flow statements are not supported") } if p.atIdent("delete") { return p.parseDelete() } - if p.atIdent("getline") { - return nil, fmt.Errorf("getline is not supported") - } x, err := p.parseExpression(0) if err != nil { return nil, err @@ -200,6 +286,30 @@ func (p *parser) parseStatement() (stmt, error) { return &exprStmt{x: x}, nil } +func (p *parser) parseExit() (stmt, error) { + p.advance() + if p.at(tokRBrace) || p.at(tokEOF) || isSeparator(p.cur().kind) { + return &exitStmt{}, nil + } + status, err := p.parseExpression(0) + if err != nil { + return nil, err + } + return &exitStmt{status: status}, nil +} + +func (p *parser) parseReturn() (stmt, error) { + p.advance() + if p.at(tokRBrace) || p.at(tokEOF) || isSeparator(p.cur().kind) { + return &returnStmt{}, nil + } + x, err := p.parseExpression(0) + if err != nil { + return nil, err + } + return &returnStmt{value: x}, nil +} + func (p *parser) parseFor() (stmt, error) { p.advance() if !p.match(tokLParen) { @@ -225,11 +335,11 @@ func (p *parser) parseFor() (stmt, error) { if !p.match(tokRParen) { return nil, fmt.Errorf("expected ) after for loop") } - body, err := p.parseStatementGroup() + body, braced, err := p.parseStatementGroup() if err != nil { return nil, err } - return &forInStmt{varName: varName, arrayName: arrayName, body: body}, nil + return &forInStmt{varName: varName, arrayName: arrayName, body: body, endsBlock: braced}, nil } init, err := p.parseOptionalForExpr(tokSemicolon) if err != nil { @@ -252,11 +362,11 @@ func (p *parser) parseFor() (stmt, error) { if !p.match(tokRParen) { return nil, fmt.Errorf("expected ) after for loop") } - body, err := p.parseStatementGroup() + body, braced, err := p.parseStatementGroup() if err != nil { return nil, err } - return &forStmt{init: init, cond: cond, post: post, body: body}, nil + return &forStmt{init: init, cond: cond, post: post, body: body, endsBlock: braced}, nil } func (p *parser) parseOptionalForExpr(end tokenKind) (expr, error) { @@ -284,11 +394,11 @@ func (p *parser) parseWhile() (stmt, error) { if !p.match(tokRParen) { return nil, fmt.Errorf("expected ) after while condition") } - body, err := p.parseStatementGroup() + body, braced, err := p.parseStatementGroup() if err != nil { return nil, err } - return &whileStmt{cond: cond, body: body}, nil + return &whileStmt{cond: cond, body: body, endsBlock: braced}, nil } func (p *parser) parseIf() (stmt, error) { @@ -303,38 +413,42 @@ func (p *parser) parseIf() (stmt, error) { if !p.match(tokRParen) { return nil, fmt.Errorf("expected ) after if condition") } - thenStmts, err := p.parseStatementGroup() + thenStmts, thenBraced, err := p.parseStatementGroup() if err != nil { return nil, err } save := p.pos p.skipSeparators() var elseStmts []stmt + endsBlock := thenBraced if p.atIdent("else") { p.advance() - elseStmts, err = p.parseStatementGroup() + var elseBraced bool + elseStmts, elseBraced, err = p.parseStatementGroup() if err != nil { return nil, err } + endsBlock = elseBraced } else { p.pos = save } - return &ifStmt{cond: cond, thenStmts: thenStmts, elseStmts: elseStmts}, nil + return &ifStmt{cond: cond, thenStmts: thenStmts, elseStmts: elseStmts, endsBlock: endsBlock}, nil } -func (p *parser) parseStatementGroup() ([]stmt, error) { +func (p *parser) parseStatementGroup() ([]stmt, bool, error) { p.skipNewlines() if p.at(tokSemicolon) { - return nil, nil + return nil, false, nil } if p.match(tokLBrace) { - return p.parseStatementList() + stmts, err := p.parseStatementList() + return stmts, true, err } st, err := p.parseStatement() if err != nil { - return nil, err + return nil, false, err } - return []stmt{st}, nil + return []stmt{st}, false, nil } func (p *parser) parseDelete() (stmt, error) { @@ -350,14 +464,11 @@ func (p *parser) parseDelete() (stmt, error) { if !p.match(tokLBracket) { return &deleteStmt{name: name, all: true}, nil } - index, err := p.parseExpression(0) + indices, err := p.parseArrayIndices() if err != nil { return nil, err } - if !p.match(tokRBracket) { - return nil, fmt.Errorf("expected ] after array index") - } - return &deleteStmt{name: name, index: index}, nil + return &deleteStmt{name: name, indices: indices}, nil } func (p *parser) parsePrint() (stmt, error) { @@ -366,6 +477,14 @@ func (p *parser) parsePrint() (stmt, error) { if p.at(tokRBrace) || p.at(tokEOF) || isSeparator(p.cur().kind) { return ps, nil } + if p.at(tokPipe) { + pipe, err := p.parseOutputPipe() + if err != nil { + return nil, err + } + ps.pipe = pipe + return ps, nil + } old := p.stopPrintRedirect p.stopPrintRedirect = true defer func() { p.stopPrintRedirect = old }() @@ -375,8 +494,16 @@ func (p *parser) parsePrint() (stmt, error) { return nil, err } ps.args = append(ps.args, x) - if p.at(tokGT) || p.at(tokAppend) || p.at(tokPipe) { - return nil, fmt.Errorf("print redirection and command pipes are not supported") + if p.at(tokGT) || p.at(tokAppend) { + return nil, fmt.Errorf("print redirection is not supported") + } + if p.at(tokPipe) { + pipe, err := p.parseOutputPipe() + if err != nil { + return nil, err + } + ps.pipe = pipe + return ps, nil } if !p.match(tokComma) { break @@ -405,8 +532,16 @@ func (p *parser) parsePrintf() (stmt, error) { return nil, err } ps.args = append(ps.args, x) - if p.at(tokGT) || p.at(tokAppend) || p.at(tokPipe) { - return nil, fmt.Errorf("print redirection and command pipes are not supported") + if p.at(tokGT) || p.at(tokAppend) { + return nil, fmt.Errorf("print redirection is not supported") + } + if p.at(tokPipe) { + pipe, err := p.parseOutputPipe() + if err != nil { + return nil, err + } + ps.pipe = pipe + return ps, nil } if parenthesized { p.skipSeparators() @@ -424,9 +559,29 @@ func (p *parser) parsePrintf() (stmt, error) { } p.skipSeparators() } + if p.at(tokGT) || p.at(tokAppend) { + return nil, fmt.Errorf("print redirection is not supported") + } + if p.at(tokPipe) { + pipe, err := p.parseOutputPipe() + if err != nil { + return nil, err + } + ps.pipe = pipe + } return ps, nil } +func (p *parser) parseOutputPipe() (expr, error) { + if !p.match(tokPipe) { + return nil, fmt.Errorf("expected |") + } + old := p.stopPrintRedirect + p.stopPrintRedirect = false + defer func() { p.stopPrintRedirect = old }() + return p.parseExpression(0) +} + func (p *parser) skipNewlines() { for p.at(tokNewline) { p.advance() @@ -439,6 +594,25 @@ func (p *parser) parseExpression(minPrec int) (expr, error) { return nil, err } for { + if p.at(tokQuestion) { + if precTernary < minPrec { + break + } + p.advance() + thenExpr, err := p.parseExpression(0) + if err != nil { + return nil, err + } + if !p.match(tokColon) { + return nil, fmt.Errorf("expected : in conditional expression") + } + elseExpr, err := p.parseExpression(precAssign) + if err != nil { + return nil, err + } + left = &ternaryExpr{cond: left, then: thenExpr, els: elseExpr} + continue + } if p.at(tokInc) || p.at(tokDec) { if precPostfix < minPrec { break @@ -446,7 +620,7 @@ func (p *parser) parseExpression(minPrec int) (expr, error) { op := p.cur().lit p.advance() if !isAssignableExpr(left) { - return nil, fmt.Errorf("increment and decrement require variables") + return nil, fmt.Errorf("syntax error: increment and decrement require variables") } left = &incDecExpr{op: op, x: left} continue @@ -454,6 +628,17 @@ func (p *parser) parseExpression(minPrec int) (expr, error) { if p.stopPrintRedirect && (p.at(tokGT) || p.at(tokAppend) || p.at(tokPipe)) { break } + if p.at(tokPipe) && p.peek(1).kind == tokIdent && p.peek(1).lit == "getline" { + if precCompare < minPrec { + break + } + next, err := p.parseCommandGetline(left) + if err != nil { + return nil, err + } + left = next + continue + } if op, prec, assoc, ok := p.binaryOp(); ok { if prec < minPrec { break @@ -513,7 +698,10 @@ func (p *parser) parsePrefix() (expr, error) { return ®exExpr{pattern: tok.lit}, nil case tokIdent: p.advance() - if p.at(tokLParen) { + if tok.lit == "getline" { + return p.parseGetline(nil) + } + if p.at(tokLParen) && (tokensAdjacent(tok, p.cur()) || isKnownBuiltinFunction(tok.lit)) { return p.parseFunctionCall(tok.lit) } if tok.lit == "length" { @@ -537,6 +725,24 @@ func (p *parser) parsePrefix() (expr, error) { if err != nil { return nil, err } + if p.match(tokComma) { + parts := []expr{x} + for { + p.skipSeparators() + part, err := p.parseExpression(0) + if err != nil { + return nil, err + } + parts = append(parts, part) + p.skipSeparators() + if p.match(tokRParen) { + return &compositeExpr{parts: parts}, nil + } + if !p.match(tokComma) { + return nil, fmt.Errorf("expected , or ) in expression list") + } + } + } if !p.match(tokRParen) { return nil, fmt.Errorf("expected )") } @@ -563,31 +769,144 @@ func (p *parser) parsePrefix() (expr, error) { } } +func (p *parser) parseCommandGetline(source expr) (expr, error) { + if !p.match(tokPipe) { + return nil, fmt.Errorf("expected |") + } + if !p.atIdent("getline") { + return nil, fmt.Errorf("expected getline") + } + p.advance() + return p.parseGetline(source) +} + +func (p *parser) parseGetline(command expr) (expr, error) { + g := &getlineExpr{source: command} + if command != nil { + g.kind = getlineCommand + } else { + g.kind = getlineMain + } + if command == nil && p.at(tokLT) { + source, err := p.parseGetlineRedirection() + if err != nil { + return nil, err + } + g.kind = getlineFile + g.source = source + return g, nil + } + if p.canStartGetlineTarget() { + target, err := p.parseGetlineTarget() + if err != nil { + return nil, err + } + g.target = target + } + if command == nil && p.at(tokLT) { + source, err := p.parseGetlineRedirection() + if err != nil { + return nil, err + } + g.kind = getlineFile + g.source = source + } + return g, nil +} + +func (p *parser) parseGetlineRedirection() (expr, error) { + if !p.match(tokLT) { + return nil, fmt.Errorf("expected <") + } + return p.parseExpression(precConcat + 1) +} + +func (p *parser) canStartGetlineTarget() bool { + return p.at(tokIdent) || p.at(tokDollar) +} + +func (p *parser) parseGetlineTarget() (expr, error) { + switch tok := p.cur(); tok.kind { + case tokIdent: + p.advance() + if err := validateIdentifierReference(tok.lit); err != nil { + return nil, err + } + if p.at(tokLBracket) { + return p.parseArrayRef(tok.lit) + } + return &varExpr{name: tok.lit}, nil + case tokDollar: + return p.parseFieldRef() + default: + return nil, fmt.Errorf("syntax error: getline requires an assignable target") + } +} + +func tokensAdjacent(left, right token) bool { + return left.pos+len(left.lit) == right.pos +} + +func isKnownBuiltinFunction(name string) bool { + if name == "system" { + return true + } + if _, ok := supportedBuiltinFunctions[name]; ok { + return true + } + _, ok := unsupportedBuiltinFunctions[name] + return ok +} + func (p *parser) parseArrayRef(name string) (expr, error) { p.advance() - index, err := p.parseExpression(0) + indices, err := p.parseArrayIndices() if err != nil { return nil, err } - if !p.match(tokRBracket) { - return nil, fmt.Errorf("expected ] after array index") + return &arrayRefExpr{name: name, indices: indices}, nil +} + +func (p *parser) parseArrayIndices() ([]expr, error) { + indices := []expr{} + for { + p.skipSeparators() + index, err := p.parseExpression(0) + if err != nil { + return nil, err + } + indices = append(indices, index) + p.skipSeparators() + if p.match(tokRBracket) { + return indices, nil + } + if !p.match(tokComma) { + return nil, fmt.Errorf("expected , or ] after array index") + } } - return &arrayRefExpr{name: name, index: index}, nil } func (p *parser) parseFunctionCall(name string) (expr, error) { - if _, ok := supportedBuiltinFunctions[name]; !ok { - if name == "system" { - return nil, fmt.Errorf("system() is not supported") - } + if msg, ok := unsupportedExpressionKeyword(name); ok { + return nil, fmt.Errorf("%s", msg) + } + if name == "system" { + return nil, fmt.Errorf("system() is not supported") + } + _, supportedBuiltin := supportedBuiltinFunctions[name] + if _, ok := unsupportedBuiltinFunctions[name]; ok { return nil, fmt.Errorf("function calls are not supported") } p.advance() args := []expr{} p.skipSeparators() if p.match(tokRParen) { - if err := validateBuiltinCallArity(name, len(args)); err != nil { - return nil, err + if supportedBuiltin { + if err := validateBuiltinCallArity(name, len(args)); err != nil { + return nil, err + } + } else if !validVarName(name) { + return nil, fmt.Errorf("invalid function name %q", name) } return &callExpr{name: name}, nil } @@ -606,12 +925,276 @@ func (p *parser) parseFunctionCall(name string) (expr, error) { return nil, fmt.Errorf("expected , or ) in function call") } } - if err := validateBuiltinCallArity(name, len(args)); err != nil { - return nil, err + if supportedBuiltin { + if err := validateBuiltinCallArity(name, len(args)); err != nil { + return nil, err + } + } else if !validVarName(name) { + return nil, fmt.Errorf("invalid function name %q", name) } return &callExpr{name: name, args: args}, nil } +func validateFunctionName(name string) error { + if !validVarName(name) { + return fmt.Errorf("invalid function name %q", name) + } + if _, ok := supportedBuiltinFunctions[name]; ok { + return fmt.Errorf("%q is a built-in function, it cannot be redefined", name) + } + if _, ok := unsupportedBuiltinFunctions[name]; ok { + return fmt.Errorf("%q is a built-in function, it cannot be redefined", name) + } + if isReservedAwkVariableName(name) { + return fmt.Errorf("function name %q uses a reserved awk variable name", name) + } + if name == "system" { + return fmt.Errorf("system() is not supported") + } + return nil +} + +func validateFunctionParameterName(functionName, param string) error { + if !validVarName(param) { + return fmt.Errorf("invalid function parameter %q", param) + } + if functionName == param { + return fmt.Errorf("function %q cannot use function name as parameter name", functionName) + } + if isReservedAwkVariableName(param) { + return fmt.Errorf("parameter %q uses a reserved awk variable name", param) + } + if _, ok := supportedBuiltinFunctions[param]; ok { + return fmt.Errorf("parameter %q uses a built-in function name", param) + } + if _, ok := unsupportedBuiltinFunctions[param]; ok { + return fmt.Errorf("parameter %q uses a built-in function name", param) + } + if msg, ok := unsupportedExpressionKeyword(param); ok { + return fmt.Errorf("%s", msg) + } + return nil +} + +func validateLoopControlStatements(prog *program) error { + for _, r := range prog.rules { + if err := validateStmtListLoopControl(r.action, 0); err != nil { + return err + } + } + for _, fn := range prog.functions { + if err := validateStmtListLoopControl(fn.body, 0); err != nil { + return err + } + } + return nil +} + +func validateStmtListLoopControl(stmts []stmt, loopDepth int) error { + for _, st := range stmts { + if err := validateStmtLoopControl(st, loopDepth); err != nil { + return err + } + } + return nil +} + +func validateStmtLoopControl(st stmt, loopDepth int) error { + switch s := st.(type) { + case *ifStmt: + if err := validateStmtListLoopControl(s.thenStmts, loopDepth); err != nil { + return err + } + return validateStmtListLoopControl(s.elseStmts, loopDepth) + case *forInStmt: + return validateStmtListLoopControl(s.body, loopDepth+1) + case *forStmt: + return validateStmtListLoopControl(s.body, loopDepth+1) + case *whileStmt: + return validateStmtListLoopControl(s.body, loopDepth+1) + case *breakStmt: + if loopDepth == 0 { + return fmt.Errorf("break is not allowed outside a loop") + } + case *continueStmt: + if loopDepth == 0 { + return fmt.Errorf("continue is not allowed outside a loop") + } + } + return nil +} + +func validateUserFunctionNameReferences(prog *program) error { + if len(prog.functions) == 0 { + return nil + } + for _, r := range prog.rules { + if err := validateExprUserFunctionNameReferences(r.pattern, prog.functions, nil); err != nil { + return err + } + if err := validateStmtListUserFunctionNameReferences(r.action, prog.functions, nil); err != nil { + return err + } + } + for _, fn := range prog.functions { + locals := make(map[string]struct{}, len(fn.params)) + for _, param := range fn.params { + locals[param] = struct{}{} + } + if err := validateStmtListUserFunctionNameReferences(fn.body, prog.functions, locals); err != nil { + return err + } + } + return nil +} + +func validateStmtListUserFunctionNameReferences(stmts []stmt, functions map[string]*functionDef, locals map[string]struct{}) error { + for _, st := range stmts { + if err := validateStmtUserFunctionNameReferences(st, functions, locals); err != nil { + return err + } + } + return nil +} + +func validateStmtUserFunctionNameReferences(st stmt, functions map[string]*functionDef, locals map[string]struct{}) error { + switch s := st.(type) { + case *printStmt: + if err := validateExprListUserFunctionNameReferences(s.args, functions, locals); err != nil { + return err + } + return validateExprUserFunctionNameReferences(s.pipe, functions, locals) + case *printfStmt: + if err := validateExprListUserFunctionNameReferences(s.args, functions, locals); err != nil { + return err + } + return validateExprUserFunctionNameReferences(s.pipe, functions, locals) + case *ifStmt: + if err := validateExprUserFunctionNameReferences(s.cond, functions, locals); err != nil { + return err + } + if err := validateStmtListUserFunctionNameReferences(s.thenStmts, functions, locals); err != nil { + return err + } + return validateStmtListUserFunctionNameReferences(s.elseStmts, functions, locals) + case *forInStmt: + if err := validateNameNotUserFunction(s.varName, functions, locals); err != nil { + return err + } + if err := validateNameNotUserFunction(s.arrayName, functions, locals); err != nil { + return err + } + return validateStmtListUserFunctionNameReferences(s.body, functions, locals) + case *forStmt: + if err := validateExprUserFunctionNameReferences(s.init, functions, locals); err != nil { + return err + } + if err := validateExprUserFunctionNameReferences(s.cond, functions, locals); err != nil { + return err + } + if err := validateExprUserFunctionNameReferences(s.post, functions, locals); err != nil { + return err + } + return validateStmtListUserFunctionNameReferences(s.body, functions, locals) + case *whileStmt: + if err := validateExprUserFunctionNameReferences(s.cond, functions, locals); err != nil { + return err + } + return validateStmtListUserFunctionNameReferences(s.body, functions, locals) + case *exitStmt: + return validateExprUserFunctionNameReferences(s.status, functions, locals) + case *returnStmt: + return validateExprUserFunctionNameReferences(s.value, functions, locals) + case *deleteStmt: + if err := validateNameNotUserFunction(s.name, functions, locals); err != nil { + return err + } + return validateExprListUserFunctionNameReferences(s.indices, functions, locals) + case *exprStmt: + return validateExprUserFunctionNameReferences(s.x, functions, locals) + default: + return nil + } +} + +func validateExprListUserFunctionNameReferences(exprs []expr, functions map[string]*functionDef, locals map[string]struct{}) error { + for _, x := range exprs { + if err := validateExprUserFunctionNameReferences(x, functions, locals); err != nil { + return err + } + } + return nil +} + +func validateExprUserFunctionNameReferences(x expr, functions map[string]*functionDef, locals map[string]struct{}) error { + switch e := x.(type) { + case nil, *numberExpr, *stringExpr, *regexExpr: + return nil + case *varExpr: + return validateNameNotUserFunction(e.name, functions, locals) + case *arrayRefExpr: + if err := validateNameNotUserFunction(e.name, functions, locals); err != nil { + return err + } + return validateExprListUserFunctionNameReferences(e.indices, functions, locals) + case *compositeExpr: + return validateExprListUserFunctionNameReferences(e.parts, functions, locals) + case *fieldExpr: + return validateExprUserFunctionNameReferences(e.index, functions, locals) + case *groupedExpr: + return validateExprUserFunctionNameReferences(e.x, functions, locals) + case *unaryExpr: + return validateExprUserFunctionNameReferences(e.x, functions, locals) + case *binaryExpr: + if err := validateExprUserFunctionNameReferences(e.left, functions, locals); err != nil { + return err + } + return validateExprUserFunctionNameReferences(e.right, functions, locals) + case *ternaryExpr: + if err := validateExprUserFunctionNameReferences(e.cond, functions, locals); err != nil { + return err + } + if err := validateExprUserFunctionNameReferences(e.then, functions, locals); err != nil { + return err + } + return validateExprUserFunctionNameReferences(e.els, functions, locals) + case *rangeExpr: + if err := validateExprUserFunctionNameReferences(e.start, functions, locals); err != nil { + return err + } + return validateExprUserFunctionNameReferences(e.end, functions, locals) + case *assignExpr: + if err := validateExprUserFunctionNameReferences(e.left, functions, locals); err != nil { + return err + } + return validateExprUserFunctionNameReferences(e.right, functions, locals) + case *incDecExpr: + return validateExprUserFunctionNameReferences(e.x, functions, locals) + case *callExpr: + if _, ok := locals[e.name]; ok { + return fmt.Errorf("parameter %q cannot be called as a function", e.name) + } + return validateExprListUserFunctionNameReferences(e.args, functions, locals) + case *getlineExpr: + if err := validateExprUserFunctionNameReferences(e.target, functions, locals); err != nil { + return err + } + return validateExprUserFunctionNameReferences(e.source, functions, locals) + default: + return nil + } +} + +func validateNameNotUserFunction(name string, functions map[string]*functionDef, locals map[string]struct{}) error { + if _, ok := locals[name]; ok { + return nil + } + if _, ok := functions[name]; ok { + return fmt.Errorf("function %q cannot be used as a variable or array", name) + } + return nil +} + func validateBuiltinCallArity(name string, argc int) error { switch name { case "length": @@ -630,6 +1213,34 @@ func validateBuiltinCallArity(name string, argc int) error { if argc != 2 && argc != 3 { return fmt.Errorf("split expects 2 or 3 arguments") } + case "sub", "gsub": + if argc != 2 && argc != 3 { + return fmt.Errorf("%s expects 2 or 3 arguments", name) + } + case "match": + if argc != 2 && argc != 3 { + return fmt.Errorf("match expects 2 or 3 arguments") + } + case "sprintf": + if argc < 1 { + return fmt.Errorf("sprintf expects at least 1 argument") + } + case "gensub": + if argc != 3 && argc != 4 { + return fmt.Errorf("gensub expects 3 or 4 arguments") + } + case "strtonum": + if argc != 1 { + return fmt.Errorf("strtonum expects 1 argument") + } + case "asorti": + if argc != 1 && argc != 2 { + return fmt.Errorf("asorti expects 1 or 2 arguments") + } + case "close": + if argc != 1 { + return fmt.Errorf("close expects 1 argument") + } case "tolower", "toupper", "int": if argc != 1 { return fmt.Errorf("%s expects 1 argument", name) @@ -667,12 +1278,10 @@ func unsupportedExpressionKeyword(name string) (string, bool) { switch name { case "BEGIN", "END": return "BEGIN and END are reserved patterns", true - case "if", "while", "for", "next", "nextfile", "exit", "break", "continue": + case "if", "while", "for", "next", "nextfile", "exit", "break", "continue", "return", "function": return "control flow statements are not supported", true case "delete": return "arrays are not supported", true - case "getline": - return "getline is not supported", true case "printf": return "printf is not supported", true case "print": diff --git a/builtins/awk/parser_test.go b/builtins/awk/parser_test.go index 3330c2ec..b26a2713 100644 --- a/builtins/awk/parser_test.go +++ b/builtins/awk/parser_test.go @@ -25,8 +25,6 @@ func TestParseRejectsUnsafeFeatures(t *testing.T) { for _, src := range []string{ `{ system("sh") }`, `{ print $1 > "out" }`, - `{ "cmd" | getline }`, - `{ exit 1 }`, } { _, err := parseProgram(src) require.Error(t, err, src) diff --git a/builtins/awk/runtime.go b/builtins/awk/runtime.go index 31ede07d..9f83bd59 100644 --- a/builtins/awk/runtime.go +++ b/builtins/awk/runtime.go @@ -7,6 +7,7 @@ package awk import ( "bufio" + "bytes" "context" "errors" "fmt" @@ -25,6 +26,7 @@ const ( MaxRecordBytes = 1 << 20 MaxFields = 16_384 MaxVariableBytes = 1 << 20 + MaxPipeBytes = 5 << 20 maxFiniteFloat64 = 1.79769313486231570814527423731704357e+308 ) @@ -191,21 +193,38 @@ func numericPrefix(s string) string { } type runtime struct { - callCtx *builtins.CallContext - prog *program - vars map[string]value - arrays map[string]map[string]value - varSizes map[string]int - arraySizes map[arraySlot]int - varBytes int - rangeOn map[int]bool - environSet bool + callCtx *builtins.CallContext + prog *program + vars map[string]value + arrays map[string]map[string]value + varSizes map[string]int + arraySizes map[arraySlot]int + varBytes int + rangeOn map[int]bool + environSet bool + frames []callFrame + ctx context.Context + futureStmts []stmt + pipes map[string]*commandPipe + flushedPipes map[string]uint8 + pipeOrder []string + stdoutBuf bytes.Buffer + inputArgs []string + inputIndex int + mainInput *recordSource + mainHadInput bool + mainUsedStdin bool + mainDefaultStdin bool + fileInputs map[string]*recordSource + failedFileInputs map[string]bool + commandInputs map[string]*commandInputPipe record string fields []string filename string nr int fnr int + exitCode int } type arraySlot struct { @@ -213,59 +232,149 @@ type arraySlot struct { key string } +type callFrame struct { + locals map[string]*localVar +} + +type commandPipe struct { + command string + buf bytes.Buffer + writes int +} + +type commandInputPipe struct { + command string + source *recordSource + status uint8 +} + +type recordSource struct { + name string + rc io.ReadCloser + sc *bufio.Scanner + rt *runtime +} + +type localVar struct { + value value + valueSize int + valueSet bool + array map[string]value + arraySizes map[string]int + arrayAlias *localVar + globalArrayName string +} + func newRuntime(callCtx *builtins.CallContext, prog *program) *runtime { rt := &runtime{ - callCtx: callCtx, - prog: prog, - vars: make(map[string]value), - arrays: make(map[string]map[string]value), - varSizes: make(map[string]int), - arraySizes: make(map[arraySlot]int), - rangeOn: make(map[int]bool), + callCtx: callCtx, + prog: prog, + vars: make(map[string]value), + arrays: make(map[string]map[string]value), + varSizes: make(map[string]int), + arraySizes: make(map[arraySlot]int), + rangeOn: make(map[int]bool), + pipes: make(map[string]*commandPipe), + flushedPipes: make(map[string]uint8), + fileInputs: make(map[string]*recordSource), + failedFileInputs: make(map[string]bool), + commandInputs: make(map[string]*commandInputPipe), } rt.vars["FS"] = stringValue(" ") + rt.vars["RS"] = stringValue("\n") rt.vars["OFS"] = stringValue(" ") rt.vars["ORS"] = stringValue("\n") + rt.vars["SUBSEP"] = stringValue("\034") + rt.vars["RSTART"] = numberValue(0) + rt.vars["RLENGTH"] = numberValue(-1) return rt } func (rt *runtime) run(ctx context.Context, files []string) builtins.Result { + rt.inputArgs = append([]string{}, files...) + defer rt.closeAllInputs() + exited := false if err := rt.runRules(ctx, ruleBegin); err != nil { - rt.callCtx.Errf("awk: %v\n", err) - return builtins.Result{Code: 1} - } - if rt.needsInput() { - if len(files) == 0 { - files = []string{"-"} + if code, ok := exitCodeFromError(err); ok { + rt.exitCode = code + exited = true + } else { + return rt.errorResult(err) } - ranInput := false - for _, file := range files { - assigned, err := rt.applyOperandAssignment(file) + } + if !exited && rt.needsInput() { + for { + rec, ok, err := rt.readMainRecord(ctx) if err != nil { - rt.callCtx.Errf("awk: %v\n", err) - return builtins.Result{Code: 1} + if code, ok := exitCodeFromError(err); ok { + rt.exitCode = code + exited = true + break + } + return rt.errorResult(err) } - if assigned { - continue + if !ok { + break } - ranInput = true - if err := rt.runFile(ctx, file); err != nil { - rt.callCtx.Errf("awk: %s: %v\n", file, err) - return builtins.Result{Code: 1} + if err := rt.setRecord(rec); err != nil { + return rt.errorResult(err) } - } - if !ranInput { - if err := rt.runFile(ctx, "-"); err != nil { - rt.callCtx.Errf("awk: -: %v\n", err) - return builtins.Result{Code: 1} + if err := rt.runRules(ctx, ruleNormal); err != nil { + if errors.Is(err, errNextRecord) { + continue + } + if code, ok := exitCodeFromError(err); ok { + rt.exitCode = code + exited = true + break + } + return rt.errorResult(err) } } } if err := rt.runRules(ctx, ruleEnd); err != nil { - rt.callCtx.Errf("awk: %v\n", err) - return builtins.Result{Code: 1} + if code, ok := exitCodeFromError(err); ok { + rt.exitCode = code + } else { + return rt.errorResult(err) + } + } + if err := rt.closeAllCommandPipes(ctx); err != nil { + return rt.errorResult(err) + } + rt.flushStdoutBuffer() + return builtins.Result{Code: normalizeAwkExitCode(rt.exitCode)} +} + +func (rt *runtime) errorResult(err error) builtins.Result { + rt.callCtx.Errf("awk: %v\n", err) + code := uint8(1) + if isFatalError(err) { + code = 2 + } + return builtins.Result{Code: code} +} + +func isFatalError(err error) bool { + const prefix = "fatal: " + msg := err.Error() + return len(msg) >= len(prefix) && msg[:len(prefix)] == prefix +} + +func exitCodeFromError(err error) (int, bool) { + exit, ok := err.(*exitError) + if ok { + return exit.code, true + } + return 0, false +} + +func normalizeAwkExitCode(code int) uint8 { + code %= 256 + if code < 0 { + code += 256 } - return builtins.Result{} + return uint8(code) } func (rt *runtime) ensureEnviron() { @@ -306,48 +415,111 @@ func (rt *runtime) needsInput() bool { return false } -func (rt *runtime) runFile(ctx context.Context, file string) error { +func (rt *runtime) readMainRecord(ctx context.Context) (string, bool, error) { + for { + if rt.mainInput == nil { + ok, err := rt.openNextMainInput(ctx) + if err != nil || !ok { + return "", false, err + } + } + rec, ok, err := rt.mainInput.readRecord(ctx) + if err != nil { + return "", false, fmt.Errorf("%s: %v", rt.mainInput.name, err) + } + if ok { + rt.nr++ + rt.fnr++ + return rec, true, nil + } + rt.mainInput.close() + rt.mainInput = nil + } +} + +func (rt *runtime) openNextMainInput(ctx context.Context) (bool, error) { + for rt.inputIndex < len(rt.inputArgs) { + arg := rt.inputArgs[rt.inputIndex] + rt.inputIndex++ + assigned, err := rt.applyOperandAssignment(arg) + if err != nil { + return false, err + } + if assigned { + continue + } + return rt.openMainInput(ctx, arg) + } + if !rt.mainHadInput && !rt.mainDefaultStdin { + rt.mainDefaultStdin = true + return rt.openMainInput(ctx, "-") + } + return false, nil +} + +func (rt *runtime) openMainInput(ctx context.Context, file string) (bool, error) { rc, err := rt.openInput(ctx, file) if err != nil { - return err + return false, fmt.Errorf("fatal: cannot open file `%s' for reading: %v", file, err) + } + rt.mainHadInput = true + if file == "-" { + rt.mainUsedStdin = true } - defer rc.Close() rt.filename = file rt.fnr = 0 + rt.mainInput = rt.newRecordSource(file, rc) + return true, nil +} + +func (rt *runtime) newRecordSource(name string, rc io.ReadCloser) *recordSource { + src := &recordSource{name: name, rc: rc, rt: rt} sc := bufio.NewScanner(rc) - sc.Split(scanAwkRecord) + sc.Split(func(data []byte, atEOF bool) (int, []byte, error) { + return scanAwkRecord(data, atEOF, src.recordSeparator()) + }) sc.Buffer(make([]byte, 4096), MaxRecordBytes+1) - for sc.Scan() { - if err := ctx.Err(); err != nil { - return err - } - rec := sc.Text() - if len(rec) > MaxRecordBytes { - return fmt.Errorf("record exceeds %d bytes", MaxRecordBytes) - } - if err := rt.setRecord(rec); err != nil { - return err - } - rt.nr++ - rt.fnr++ - if err := rt.runRules(ctx, ruleNormal); err != nil { - if errors.Is(err, errNextRecord) { - continue - } - return err + src.sc = sc + return src +} + +func (src *recordSource) recordSeparator() string { + if src == nil || src.rt == nil { + return "\n" + } + return src.rt.getVar("RS").String() +} + +func (src *recordSource) readRecord(ctx context.Context) (string, bool, error) { + if err := ctx.Err(); err != nil { + return "", false, err + } + if !src.sc.Scan() { + if err := src.sc.Err(); err != nil { + return "", false, err } + return "", false, nil } - if err := sc.Err(); err != nil { - return err + rec := src.sc.Text() + if len(rec) > MaxRecordBytes { + return "", false, fmt.Errorf("record exceeds %d bytes", MaxRecordBytes) } - return nil + return rec, true, nil } -func scanAwkRecord(data []byte, atEOF bool) (int, []byte, error) { - for i, b := range data { - if b == '\n' { - return i + 1, data[:i], nil - } +func (src *recordSource) close() { + if src != nil && src.rc != nil { + src.rc.Close() + } +} + +func scanAwkRecord(data []byte, atEOF bool, rs string) (int, []byte, error) { + if err := validateRS(rs); err != nil { + return 0, nil, err + } + sep := []byte(rs) + if i := indexBytes(data, sep); i >= 0 { + return i + len(sep), data[:i], nil } if atEOF { if len(data) == 0 { @@ -358,6 +530,25 @@ func scanAwkRecord(data []byte, atEOF bool) (int, []byte, error) { return 0, nil, nil } +func indexBytes(data, sep []byte) int { + if len(sep) == 0 { + return -1 + } + for i := 0; i+len(sep) <= len(data); i++ { + matched := true + for j := range sep { + if data[i+j] != sep[j] { + matched = false + break + } + } + if matched { + return i + } + } + return -1 +} + func (rt *runtime) openInput(ctx context.Context, file string) (io.ReadCloser, error) { if file == "-" { if rt.callCtx.Stdin == nil { @@ -372,7 +563,501 @@ func (rt *runtime) openInput(ctx context.Context, file string) (io.ReadCloser, e return f, nil } +func (rt *runtime) writeCommandPipe(ctx context.Context, target expr, out string) error { + commandValue, err := rt.eval(target) + if err != nil { + return err + } + command := commandValue.String() + if command == "" { + return fmt.Errorf("expression for `|' redirection has null string value") + } + pipe, err := rt.commandPipe(command) + if err != nil { + return err + } + if len(out) > MaxPipeBytes-pipe.buf.Len() { + return fmt.Errorf("command pipe %q input exceeds %d bytes", command, MaxPipeBytes) + } + if _, err := pipe.buf.WriteString(out); err != nil { + return err + } + pipe.writes++ + return ctx.Err() +} + +func (rt *runtime) commandPipe(command string) (*commandPipe, error) { + if pipe, ok := rt.pipes[command]; ok { + return pipe, nil + } + delete(rt.flushedPipes, command) + pipe := &commandPipe{command: command} + rt.pipes[command] = pipe + rt.pipeOrder = append(rt.pipeOrder, command) + return pipe, nil +} + +func (rt *runtime) closeCommandPipe(ctx context.Context, command string, flushStdoutBefore bool) (uint8, bool, error) { + pipe, ok := rt.pipes[command] + if !ok { + if status, ok := rt.flushedPipes[command]; ok { + delete(rt.flushedPipes, command) + return status, true, nil + } + return 0, false, nil + } + delete(rt.pipes, command) + rt.removeCommandPipeOrder(command) + if flushStdoutBefore { + rt.flushStdoutBuffer() + } + status, err := rt.runCommandPipe(ctx, pipe) + return status, true, err +} + +func (rt *runtime) removeCommandPipeOrder(command string) { + for i, candidate := range rt.pipeOrder { + if candidate == command { + copy(rt.pipeOrder[i:], rt.pipeOrder[i+1:]) + rt.pipeOrder = rt.pipeOrder[:len(rt.pipeOrder)-1] + return + } + } +} + +func (rt *runtime) closeAllCommandPipes(ctx context.Context) error { + for len(rt.pipeOrder) > 0 { + command := rt.pipeOrder[0] + _, _, err := rt.closeCommandPipe(ctx, command, false) + if err != nil { + return err + } + } + return nil +} + +func (rt *runtime) flushCommandPipesForStdout(ctx context.Context, remaining []stmt) error { + for _, command := range append([]string(nil), rt.pipeOrder...) { + if rt.commandPipeNextAction(command, remaining) != commandPipeActionNone { + continue + } + status, ok, err := rt.closeCommandPipe(ctx, command, false) + if err != nil { + return err + } + if ok { + rt.flushedPipes[command] = status + } + } + return nil +} + +func (rt *runtime) shouldBufferStdoutForPipes(remaining []stmt) bool { + if rt.stdoutBuf.Len() > 0 { + return true + } + for _, command := range rt.pipeOrder { + if rt.commandPipeNextAction(command, remaining) != commandPipeActionNone { + return true + } + } + return false +} + +func (rt *runtime) commandPipeNextAction(command string, stmts []stmt) commandPipeAction { + return rt.stmtsCommandPipeAction(command, stmts, nil) +} + +type commandPipeAction int + +const ( + commandPipeActionNone commandPipeAction = iota + commandPipeActionWrite + commandPipeActionClose +) + +func (rt *runtime) stmtsCommandPipeAction(command string, stmts []stmt, seen map[string]bool) commandPipeAction { + for _, st := range stmts { + if action := rt.stmtCommandPipeAction(command, st, seen); action != commandPipeActionNone { + return action + } + } + return commandPipeActionNone +} + +func (rt *runtime) stmtCommandPipeAction(command string, st stmt, seen map[string]bool) commandPipeAction { + switch s := st.(type) { + case *printStmt: + if action := rt.exprsCommandPipeAction(command, s.args, seen); action != commandPipeActionNone { + return action + } + return pipeExprCommandPipeAction(s.pipe, command) + case *printfStmt: + if action := rt.exprsCommandPipeAction(command, s.args, seen); action != commandPipeActionNone { + return action + } + return pipeExprCommandPipeAction(s.pipe, command) + case *ifStmt: + if action := rt.exprCommandPipeAction(command, s.cond, seen); action != commandPipeActionNone { + return action + } + return mergeBranchCommandPipeAction( + rt.stmtsCommandPipeAction(command, s.thenStmts, seen), + rt.stmtsCommandPipeAction(command, s.elseStmts, seen), + ) + case *forInStmt: + return rt.stmtsCommandPipeAction(command, s.body, seen) + case *forStmt: + forParts := []expr{s.init, s.cond, s.post} + if action := rt.exprsCommandPipeAction(command, forParts, seen); action != commandPipeActionNone { + return action + } + return rt.stmtsCommandPipeAction(command, s.body, seen) + case *whileStmt: + if action := rt.exprCommandPipeAction(command, s.cond, seen); action != commandPipeActionNone { + return action + } + return rt.stmtsCommandPipeAction(command, s.body, seen) + case *deleteStmt: + return rt.exprsCommandPipeAction(command, s.indices, seen) + case *exitStmt: + return rt.exprCommandPipeAction(command, s.status, seen) + case *returnStmt: + return rt.exprCommandPipeAction(command, s.value, seen) + case *exprStmt: + return rt.exprCommandPipeAction(command, s.x, seen) + default: + return commandPipeActionNone + } +} + +func mergeBranchCommandPipeAction(left, right commandPipeAction) commandPipeAction { + if left == commandPipeActionWrite || right == commandPipeActionWrite { + return commandPipeActionWrite + } + if left == commandPipeActionClose || right == commandPipeActionClose { + return commandPipeActionClose + } + return commandPipeActionNone +} + +func pipeExprCommandPipeAction(pipe expr, command string) commandPipeAction { + if pipe == nil { + return commandPipeActionNone + } + if static, ok := staticStringExpr(pipe); ok { + if static == command { + return commandPipeActionWrite + } + return commandPipeActionNone + } + return commandPipeActionWrite +} + +func (rt *runtime) exprsCommandPipeAction(command string, exprs []expr, seen map[string]bool) commandPipeAction { + for _, x := range exprs { + if action := rt.exprCommandPipeAction(command, x, seen); action != commandPipeActionNone { + return action + } + } + return commandPipeActionNone +} + +func (rt *runtime) exprCommandPipeAction(command string, x expr, seen map[string]bool) commandPipeAction { + if x == nil { + return commandPipeActionNone + } + switch e := x.(type) { + case *arrayRefExpr: + return rt.exprsCommandPipeAction(command, e.indices, seen) + case *compositeExpr: + return rt.exprsCommandPipeAction(command, e.parts, seen) + case *fieldExpr: + return rt.exprCommandPipeAction(command, e.index, seen) + case *groupedExpr: + return rt.exprCommandPipeAction(command, e.x, seen) + case *unaryExpr: + return rt.exprCommandPipeAction(command, e.x, seen) + case *binaryExpr: + if action := rt.exprCommandPipeAction(command, e.left, seen); action != commandPipeActionNone { + return action + } + return rt.exprCommandPipeAction(command, e.right, seen) + case *ternaryExpr: + if action := rt.exprCommandPipeAction(command, e.cond, seen); action != commandPipeActionNone { + return action + } + return mergeBranchCommandPipeAction( + rt.exprCommandPipeAction(command, e.then, seen), + rt.exprCommandPipeAction(command, e.els, seen), + ) + case *assignExpr: + if action := rt.exprCommandPipeAction(command, e.left, seen); action != commandPipeActionNone { + return action + } + return rt.exprCommandPipeAction(command, e.right, seen) + case *incDecExpr: + return rt.exprCommandPipeAction(command, e.x, seen) + case *callExpr: + if action := rt.exprsCommandPipeAction(command, e.args, seen); action != commandPipeActionNone { + return action + } + if e.name == "close" && len(e.args) == 1 { + if static, ok := staticStringExpr(e.args[0]); ok { + if static == command { + return commandPipeActionClose + } + return commandPipeActionNone + } + return commandPipeActionClose + } + if fn, ok := rt.prog.functions[e.name]; ok { + if seen[e.name] { + return commandPipeActionNone + } + nextSeen := make(map[string]bool, len(seen)+1) + for name, active := range seen { + nextSeen[name] = active + } + nextSeen[e.name] = true + return rt.stmtsCommandPipeAction(command, fn.body, nextSeen) + } + case *getlineExpr: + if action := rt.exprCommandPipeAction(command, e.target, seen); action != commandPipeActionNone { + return action + } + return rt.exprCommandPipeAction(command, e.source, seen) + } + return commandPipeActionNone +} + +func staticStringExpr(x expr) (string, bool) { + switch e := x.(type) { + case *stringExpr: + return e.value, true + case *groupedExpr: + return staticStringExpr(e.x) + default: + return "", false + } +} + +func (rt *runtime) runCommandPipe(ctx context.Context, pipe *commandPipe) (uint8, error) { + if rt.callCtx.RunScriptWithStdin == nil { + return 127, fmt.Errorf("command pipes are not available") + } + dir := "" + if rt.callCtx.WorkDir != nil { + dir = rt.callCtx.WorkDir() + } + return rt.callCtx.RunScriptWithStdin(ctx, dir, pipe.command, bytes.NewReader(pipe.buf.Bytes()), rt.callCtx.Stdout) +} + +func (rt *runtime) writeStdoutString(ctx context.Context, s string, remaining []stmt) error { + if s != "" { + if rt.shouldBufferStdoutForPipes(remaining) { + _, err := rt.stdoutBuf.WriteString(s) + if err != nil { + return err + } + return ctx.Err() + } + if err := rt.flushCommandPipesForStdout(ctx, remaining); err != nil { + return err + } + } + rt.callCtx.Out(s) + return nil +} + +func (rt *runtime) flushStdoutBuffer() { + if rt.stdoutBuf.Len() == 0 { + return + } + rt.callCtx.Out(rt.stdoutBuf.String()) + rt.stdoutBuf.Reset() +} + +func (rt *runtime) getlineFileRecord(ctx context.Context, name string) (string, int, error) { + src, ok := rt.fileInputs[name] + if !ok { + opened, err := rt.openFileInput(ctx, name) + if err != nil { + return "", 0, err + } + if opened == nil { + return "", -1, nil + } + src = opened + } + rec, ok, err := src.readRecord(ctx) + if err != nil { + rt.setErrno(err) + return "", -1, nil + } + if !ok { + return "", 0, nil + } + return rec, 1, nil +} + +func (rt *runtime) openFileInput(ctx context.Context, name string) (*recordSource, error) { + if name == "" { + return nil, fmt.Errorf("fatal: expression for `<' redirection has null string value") + } + rc, err := rt.openInput(ctx, name) + if err != nil { + rt.failedFileInputs[name] = true + rt.setErrno(err) + return nil, nil + } + src := rt.newRecordSource(name, rc) + rt.fileInputs[name] = src + delete(rt.failedFileInputs, name) + return src, nil +} + +func (rt *runtime) getlineCommandRecord(ctx context.Context, command string) (string, int, error) { + pipe, ok := rt.commandInputs[command] + if !ok { + opened, err := rt.openCommandInput(ctx, command) + if err != nil { + return "", 0, err + } + pipe = opened + } + rec, ok, err := pipe.source.readRecord(ctx) + if err != nil { + rt.setErrno(err) + return "", -1, nil + } + if !ok { + return "", 0, nil + } + return rec, 1, nil +} + +func (rt *runtime) openCommandInput(ctx context.Context, command string) (*commandInputPipe, error) { + if command == "" { + return nil, fmt.Errorf("fatal: expression for `|' redirection has null string value") + } + if rt.callCtx.RunScriptWithStdin == nil { + return nil, fmt.Errorf("command pipes are not available") + } + dir := "" + if rt.callCtx.WorkDir != nil { + dir = rt.callCtx.WorkDir() + } + var out limitedBuffer + out.max = MaxPipeBytes + status, err := rt.callCtx.RunScriptWithStdin(ctx, dir, command, rt.commandInputStdin(), &out) + if out.err != nil { + return nil, out.err + } + if err != nil { + return nil, err + } + pipe := &commandInputPipe{ + command: command, + source: rt.newRecordSource(command, io.NopCloser(bytes.NewReader(out.buf.Bytes()))), + status: status, + } + rt.commandInputs[command] = pipe + return pipe, nil +} + +func (rt *runtime) commandInputStdin() io.Reader { + if rt.callCtx.Stdin != nil && !rt.mainUsedStdin { + return rt.callCtx.Stdin + } + return strings.NewReader("") +} + +type limitedBuffer struct { + buf bytes.Buffer + max int + err error +} + +func (w *limitedBuffer) Write(p []byte) (int, error) { + if w.err != nil { + return 0, w.err + } + if len(p) > w.max-w.buf.Len() { + remaining := w.max - w.buf.Len() + if remaining > 0 { + _, _ = w.buf.Write(p[:remaining]) + } + w.err = fmt.Errorf("command pipe output exceeds %d bytes", w.max) + return len(p), w.err + } + n, err := w.buf.Write(p) + if err != nil { + w.err = err + } + return n, err +} + +func (rt *runtime) closeCommandInput(command string) (uint8, bool, error) { + pipe, ok := rt.commandInputs[command] + if !ok { + return 0, false, nil + } + pipe.source.close() + delete(rt.commandInputs, command) + return pipe.status, true, nil +} + +func (rt *runtime) closeInputFile(name string) (int, bool) { + if src, ok := rt.fileInputs[name]; ok { + src.close() + delete(rt.fileInputs, name) + return 0, true + } + if rt.failedFileInputs[name] { + delete(rt.failedFileInputs, name) + return -1, true + } + return 0, false +} + +func (rt *runtime) closeAllInputs() { + if rt.mainInput != nil { + rt.mainInput.close() + rt.mainInput = nil + } + for name, src := range rt.fileInputs { + src.close() + delete(rt.fileInputs, name) + } + for command, pipe := range rt.commandInputs { + pipe.source.close() + delete(rt.commandInputs, command) + } +} + +func (rt *runtime) setErrno(err error) { + if err == nil { + return + } + msg := err.Error() + if rt.callCtx.PortableErr != nil { + msg = rt.callCtx.PortableErr(err) + } + if len(msg) > 0 && msg[0] >= 'a' && msg[0] <= 'z' { + msg = string(msg[0]-'a'+'A') + msg[1:] + } + _ = rt.setVar("ERRNO", stringValue(msg)) +} + +func (rt *runtime) setErrnoString(msg string) { + _ = rt.setVar("ERRNO", stringValue(msg)) +} + func (rt *runtime) runRules(ctx context.Context, kind ruleKind) error { + prevCtx := rt.ctx + rt.ctx = ctx + defer func() { rt.ctx = prevCtx }() for i := range rt.prog.rules { r := &rt.prog.rules[i] if err := ctx.Err(); err != nil { @@ -391,12 +1076,12 @@ func (rt *runtime) runRules(ctx context.Context, kind ruleKind) error { } } if r.action == nil { - if err := rt.printValues([]value{rt.field(0)}); err != nil { + if err := rt.writeStdoutString(ctx, rt.formatPrintValues([]value{rt.field(0)}), rt.ruleFuture(kind, i+1)); err != nil { return err } continue } - if err := rt.execStatements(ctx, r.action); err != nil { + if err := rt.execStatementsWithFuture(ctx, r.action, rt.ruleFuture(kind, i+1)); err != nil { if errors.Is(err, errNextRecord) { if kind == ruleNormal { return err @@ -409,6 +1094,30 @@ func (rt *runtime) runRules(ctx context.Context, kind ruleKind) error { return nil } +func (rt *runtime) ruleFuture(kind ruleKind, nextRule int) []stmt { + var future []stmt + future = rt.appendRuleActions(future, kind, nextRule) + switch kind { + case ruleBegin: + future = rt.appendRuleActions(future, ruleNormal, 0) + future = rt.appendRuleActions(future, ruleEnd, 0) + case ruleNormal: + future = rt.appendRuleActions(future, ruleNormal, 0) + future = rt.appendRuleActions(future, ruleEnd, 0) + } + return future +} + +func (rt *runtime) appendRuleActions(dst []stmt, kind ruleKind, start int) []stmt { + for i := start; i < len(rt.prog.rules); i++ { + r := rt.prog.rules[i] + if r.kind == kind && r.action != nil { + dst = append(dst, r.action...) + } + } + return dst +} + func (rt *runtime) matchPattern(ruleIndex int, x expr) (bool, error) { if rx, ok := x.(*rangeExpr); ok { return rt.matchRangePattern(ruleIndex, rx) @@ -446,7 +1155,7 @@ func (rt *runtime) matchRangePattern(ruleIndex int, x *rangeExpr) (bool, error) func (rt *runtime) matchSimplePattern(x expr) (bool, error) { if rx, ok := x.(*regexExpr); ok { - re, err := compileRegex(rx.pattern) + re, err := rt.compileRegex(rx.pattern) if err != nil { return false, err } @@ -465,7 +1174,7 @@ func (rt *runtime) setRecord(rec string) error { } rt.record = rec fs := rt.getVar("FS").String() - fields, err := splitAwkFields(rec, fs) + fields, err := rt.splitAwkFields(rec, fs) if err != nil { return err } @@ -558,7 +1267,7 @@ func (rt *runtime) setNF(n int) error { return rt.rebuildRecordFromFields() } -func splitAwkFields(s, fs string) ([]string, error) { +func (rt *runtime) splitAwkFields(s, fs string) ([]string, error) { if fs == " " { return splitAwkWhitespaceFields(s), nil } @@ -571,7 +1280,7 @@ func splitAwkFields(s, fs string) ([]string, error) { if isSingleRune(fs) { return strings.Split(s, fs), nil } - return splitAwkRegex(s, fs) + return rt.splitAwkRegex(s, fs) } func splitAwkWhitespaceFields(rec string) []string { @@ -606,14 +1315,14 @@ func splitAwkChars(s string) []string { return chars } -func splitAwkRegex(s, pattern string) ([]string, error) { +func (rt *runtime) splitAwkRegex(s, pattern string) ([]string, error) { if s == "" { return nil, nil } if pattern == "" { return splitAwkChars(s), nil } - re, err := compileRegex(pattern) + re, err := rt.compileRegex(pattern) if err != nil { return nil, err } @@ -644,7 +1353,51 @@ func (rt *runtime) field(n int) value { return inputStringValue(rt.fields[n-1]) } +func (rt *runtime) currentFrame() *callFrame { + if len(rt.frames) == 0 { + return nil + } + return &rt.frames[len(rt.frames)-1] +} + +func (rt *runtime) lookupLocal(name string) *localVar { + frame := rt.currentFrame() + if frame == nil { + return nil + } + return frame.locals[name] +} + +func rootLocalVar(v *localVar) *localVar { + for v != nil && v.arrayAlias != nil { + v = v.arrayAlias + } + return v +} + +func (rt *runtime) localIsArray(v *localVar) bool { + root := rootLocalVar(v) + if root == nil { + return false + } + if root.globalArrayName != "" { + return rt.isGlobalArray(root.globalArrayName) || isBuiltinArrayName(root.globalArrayName) + } + return root.array != nil +} + func (rt *runtime) getVar(name string) value { + if local := rt.lookupLocal(name); local != nil { + root := rootLocalVar(local) + if rt.localIsArray(root) { + return unassignedValue() + } + if local.valueSet { + return local.value + } + rt.markLocalScalarRead(local) + return unassignedValue() + } switch name { case "NF": return numberValue(float64(len(rt.fields))) @@ -668,6 +1421,13 @@ func (rt *runtime) getVar(name string) value { } func (rt *runtime) setVar(name string, v value) error { + if local := rt.lookupLocal(name); local != nil { + root := rootLocalVar(local) + if rt.localIsArray(root) { + return fmt.Errorf("cannot use array %s as scalar", name) + } + return rt.setLocalScalar(local, v) + } if rt.isArray(name) { return fmt.Errorf("cannot use array %s as scalar", name) } @@ -683,6 +1443,10 @@ func (rt *runtime) setVar(name string, v value) error { if err := validateFS(v.String()); err != nil { return err } + case "RS": + if err := validateRS(v.String()); err != nil { + return err + } } size := len(v.String()) if size > MaxVariableBytes { @@ -698,38 +1462,146 @@ func (rt *runtime) setVar(name string, v value) error { return nil } +func (rt *runtime) setLocalScalar(local *localVar, v value) error { + root := rootLocalVar(local) + size := len(v.String()) + if size > MaxVariableBytes { + return fmt.Errorf("variable value exceeds %d bytes", MaxVariableBytes) + } + if rt.varBytes-local.valueSize+size > MaxVariableBytes { + return fmt.Errorf("variable storage limit exceeded (%d bytes total)", rt.varBytes-local.valueSize+size) + } + rt.varBytes = rt.varBytes - local.valueSize + size + local.valueSize = size + local.value = v + local.valueSet = true + if root != nil && root != local && !rt.localIsArray(root) { + root.valueSet = true + if root.globalArrayName != "" { + rt.markGlobalScalarName(root.globalArrayName) + } + } + local.arrayAlias = nil + local.globalArrayName = "" + local.array = nil + local.arraySizes = nil + return nil +} + +func (rt *runtime) markLocalScalarRead(local *localVar) { + root := rootLocalVar(local) + if root == nil || rt.localIsArray(root) { + return + } + root.value = unassignedValue() + root.valueSet = true + if root.globalArrayName != "" { + rt.markGlobalScalarName(root.globalArrayName) + } +} + func (rt *runtime) isArray(name string) bool { + if local := rt.lookupLocal(name); local != nil { + return rt.localIsArray(local) + } + return rt.isGlobalArray(name) +} + +func (rt *runtime) isGlobalArray(name string) bool { _, ok := rt.arrays[name] return ok } -func (rt *runtime) getArrayElem(name, key string) (value, error) { +func (rt *runtime) localArrayStorage(name string, create bool) (map[string]value, *localVar, string, bool, error) { + local := rt.lookupLocal(name) + if local == nil { + return nil, nil, "", false, nil + } + root := rootLocalVar(local) + if root.valueSet && root.array == nil { + return nil, nil, "", true, fmt.Errorf("cannot use scalar %s as array", name) + } + if root.globalArrayName != "" { + actual := root.globalArrayName + rt.ensureBuiltinArray(actual) + if err := rt.validateArrayName(actual); err != nil { + return nil, nil, "", true, err + } + if create || rt.arrays[actual] != nil { + rt.markArrayName(actual) + } + return rt.arrays[actual], root, actual, true, nil + } + if root.array == nil && create { + root.array = make(map[string]value) + root.arraySizes = make(map[string]int) + } + return root.array, root, "", true, nil +} + +func (rt *runtime) ensureLocalArray(name string) (map[string]value, *localVar, string, bool, error) { + elems, local, globalName, handled, err := rt.localArrayStorage(name, true) + if handled || err != nil { + return elems, local, globalName, handled, err + } rt.ensureBuiltinArray(name) if err := rt.validateArrayName(name); err != nil { - return value{}, err + return nil, nil, "", false, err } rt.markArrayName(name) - if v, ok := rt.arrays[name][key]; ok { + return rt.arrays[name], nil, name, false, nil +} + +func (rt *runtime) getArrayElem(name, key string) (value, error) { + elems, local, globalName, handled, err := rt.ensureLocalArray(name) + if err != nil { + return value{}, err + } + if v, ok := elems[key]; ok { return v, nil } v := unassignedValue() - if err := rt.setArrayElem(name, key, v); err != nil { + if handled { + if err := rt.setLocalArrayElem(local, globalName, key, v); err != nil { + return value{}, err + } + return v, nil + } + if err := rt.setGlobalArrayElem(name, key, v); err != nil { return value{}, err } return v, nil } func (rt *runtime) hasArrayElem(name, key string) (bool, error) { - rt.ensureBuiltinArray(name) - if err := rt.validateArrayName(name); err != nil { + elems, _, _, handled, err := rt.localArrayStorage(name, true) + if err != nil { return false, err } - rt.markArrayName(name) - _, ok := rt.arrays[name][key] + if !handled { + rt.ensureBuiltinArray(name) + if err := rt.validateArrayName(name); err != nil { + return false, err + } + rt.markArrayName(name) + elems = rt.arrays[name] + } + _, ok := elems[key] return ok, nil } func (rt *runtime) setArrayElem(name, key string, v value) error { + _, local, globalName, handled, err := rt.ensureLocalArray(name) + if err != nil { + return err + } + if handled { + return rt.setLocalArrayElem(local, globalName, key, v) + } + return rt.setGlobalArrayElem(name, key, v) +} + +func (rt *runtime) setGlobalArrayElem(name, key string, v value) error { rt.ensureBuiltinArray(name) if err := rt.validateArrayName(name); err != nil { return err @@ -750,13 +1622,33 @@ func (rt *runtime) setArrayElem(name, key string, v value) error { return nil } +func (rt *runtime) setLocalArrayElem(local *localVar, globalName, key string, v value) error { + if globalName != "" { + return rt.setGlobalArrayElem(globalName, key, v) + } + root := rootLocalVar(local) + if root.array == nil { + root.array = make(map[string]value) + root.arraySizes = make(map[string]int) + } + size := len(key) + len(v.String()) + if size > MaxVariableBytes { + return fmt.Errorf("array element exceeds %d bytes", MaxVariableBytes) + } + old := root.arraySizes[key] + if rt.varBytes-old+size > MaxVariableBytes { + return fmt.Errorf("variable storage limit exceeded (%d bytes total)", rt.varBytes-old+size) + } + rt.varBytes = rt.varBytes - old + size + root.arraySizes[key] = size + root.array[key] = v + return nil +} + func (rt *runtime) replaceArray(name string, elems map[string]value) error { if err := rt.deleteArray(name); err != nil { return err } - if rt.arrays[name] == nil { - rt.arrays[name] = make(map[string]value, len(elems)) - } for key, v := range elems { if err := rt.setArrayElem(name, key, v); err != nil { return err @@ -766,6 +1658,32 @@ func (rt *runtime) replaceArray(name string, elems map[string]value) error { } func (rt *runtime) deleteArrayElem(name, key string) error { + elems, local, globalName, handled, err := rt.ensureLocalArray(name) + if err != nil { + return err + } + if handled { + if globalName != "" { + return rt.deleteGlobalArrayElem(globalName, key) + } + root := rootLocalVar(local) + if root.array == nil { + return nil + } + if old := root.arraySizes[key]; old > 0 { + rt.varBytes -= old + if rt.varBytes < 0 { + rt.varBytes = 0 + } + } + delete(root.arraySizes, key) + delete(elems, key) + return nil + } + return rt.deleteGlobalArrayElem(name, key) +} + +func (rt *runtime) deleteGlobalArrayElem(name, key string) error { rt.ensureBuiltinArray(name) if err := rt.validateArrayName(name); err != nil { return err @@ -784,6 +1702,31 @@ func (rt *runtime) deleteArrayElem(name, key string) error { } func (rt *runtime) deleteArray(name string) error { + _, local, globalName, handled, err := rt.ensureLocalArray(name) + if err != nil { + return err + } + if handled { + if globalName != "" { + return rt.deleteGlobalArray(globalName) + } + root := rootLocalVar(local) + for _, size := range root.arraySizes { + rt.varBytes -= size + } + if rt.varBytes < 0 { + rt.varBytes = 0 + } + root.array = make(map[string]value) + root.arraySizes = make(map[string]int) + root.valueSet = false + root.valueSize = 0 + return nil + } + return rt.deleteGlobalArray(name) +} + +func (rt *runtime) deleteGlobalArray(name string) error { rt.ensureBuiltinArray(name) if err := rt.validateArrayName(name); err != nil { return err @@ -804,24 +1747,39 @@ func (rt *runtime) deleteArray(name string) error { } func (rt *runtime) arrayKeys(name string) ([]string, error) { - rt.ensureBuiltinArray(name) - if err := rt.validateArrayName(name); err != nil { + return rt.arrayKeysSorted(name, false) +} + +func (rt *runtime) arrayKeysSorted(name string, ignoreCase bool) ([]string, error) { + elems, _, _, handled, err := rt.localArrayStorage(name, true) + if err != nil { return nil, err } - rt.markArrayName(name) - keys := make([]string, 0, len(rt.arrays[name])) - for key := range rt.arrays[name] { + if !handled { + rt.ensureBuiltinArray(name) + if err := rt.validateArrayName(name); err != nil { + return nil, err + } + rt.markArrayName(name) + elems = rt.arrays[name] + } + keys := make([]string, 0, len(elems)) + for key := range elems { keys = append(keys, key) } - sortStringKeys(keys) + sortStringKeys(keys, ignoreCase) return keys, nil } -func sortStringKeys(keys []string) { +func sortStringKeys(keys []string, ignoreCase bool) { for i := 1; i < len(keys); i++ { key := keys[i] + sortKey := key + if ignoreCase { + sortKey = strings.ToLower(key) + } j := i - 1 - for j >= 0 && keys[j] > key { + for j >= 0 && compareAwkSortKeys(keys[j], key, sortKey, ignoreCase) > 0 { keys[j+1] = keys[j] j-- } @@ -829,6 +1787,28 @@ func sortStringKeys(keys []string) { } } +func compareAwkSortKeys(left, right, foldedRight string, ignoreCase bool) int { + compareLeft := left + compareRight := right + if ignoreCase { + compareLeft = strings.ToLower(left) + compareRight = foldedRight + } + if compareLeft < compareRight { + return -1 + } + if compareLeft > compareRight { + return 1 + } + if left < right { + return -1 + } + if left > right { + return 1 + } + return 0 +} + func (rt *runtime) ensureBuiltinArray(name string) { if name == "ENVIRON" { rt.ensureEnviron() @@ -851,6 +1831,13 @@ func (rt *runtime) validateArrayName(name string) error { return nil } +func (rt *runtime) markGlobalScalarName(name string) { + if _, ok := rt.vars[name]; !ok { + rt.vars[name] = unassignedValue() + rt.varSizes[name] = 0 + } +} + func isBuiltinScalarName(name string) bool { switch name { case "NF", "NR", "FNR", "FILENAME": @@ -864,6 +1851,19 @@ func isBuiltinArrayName(name string) bool { return name == "ENVIRON" } +func isReservedAwkVariableName(name string) bool { + return isBuiltinScalarName(name) || isBuiltinArrayName(name) || isWritableSpecialScalarName(name) +} + +func isWritableSpecialScalarName(name string) bool { + switch name { + case "FS", "RS", "OFS", "ORS", "SUBSEP", "RSTART", "RLENGTH", "IGNORECASE": + return true + default: + return false + } +} + func validateFS(fs string) error { if fs == " " { return nil @@ -881,6 +1881,16 @@ func validateFS(fs string) error { return nil } +func validateRS(rs string) error { + if rs == "" { + return fmt.Errorf("empty RS is not supported") + } + if !isSingleRune(rs) { + return fmt.Errorf("multi-character RS is not supported") + } + return nil +} + func isSingleRune(s string) bool { if s == "" { return false @@ -889,21 +1899,181 @@ func isSingleRune(s string) bool { return size == len(s) } -func compileRegex(pattern string) (*regexp.Regexp, error) { - normalized := normalizeAwkRegex(pattern) +type awkRegex struct { + re *regexp.Regexp + byteMode bool +} + +func (rt *runtime) compileRegex(pattern string) (*awkRegex, error) { + return compileRegexWithOptions(pattern, rt.ignoreCase()) +} + +func (rt *runtime) ignoreCase() bool { + return rt.getVar("IGNORECASE").Bool() +} + +func compileRegex(pattern string) (*awkRegex, error) { + return compileRegexWithOptions(pattern, false) +} + +func compileRegexWithOptions(pattern string, ignoreCase bool) (*awkRegex, error) { + normalized, byteMode := normalizeAwkRegex(pattern) + if ignoreCase { + normalized = "(?i:" + normalized + ")" + } re, err := regexp.Compile(normalized) if err != nil { return nil, fmt.Errorf("invalid regular expression %q: %v", pattern, err) } re.Longest() - return re, nil + return &awkRegex{re: re, byteMode: byteMode}, nil +} + +func (re *awkRegex) MatchString(s string) bool { + if !re.byteMode { + return re.re.MatchString(s) + } + encoded, _ := encodeAwkRegexBytes(s) + return re.re.MatchString(encoded) +} + +func (re *awkRegex) FindStringIndex(s string) []int { + if !re.byteMode { + return re.re.FindStringIndex(s) + } + encoded, offsets := encodeAwkRegexBytes(s) + loc := re.re.FindStringIndex(encoded) + if loc == nil { + return nil + } + return []int{offsets[loc[0]], offsets[loc[1]]} +} + +func (re *awkRegex) FindStringRuneIndex(s string) []int { + loc := re.FindStringIndex(s) + if loc == nil { + return nil + } + if !re.byteMode { + return []int{runeLen(s[:loc[0]]), runeLen(s[:loc[1]])} + } + start, end := runeRangeForByteRange(s, loc[0], loc[1]) + return []int{start, end} +} + +func (re *awkRegex) FindAllStringIndex(s string, n int) [][]int { + if !re.byteMode { + return re.re.FindAllStringIndex(s, n) + } + encoded, offsets := encodeAwkRegexBytes(s) + matches := re.re.FindAllStringIndex(encoded, n) + for _, loc := range matches { + loc[0] = offsets[loc[0]] + loc[1] = offsets[loc[1]] + } + return matches +} + +func (re *awkRegex) FindStringSubmatchIndex(s string) []int { + loc := re.FindAllStringSubmatchIndex(s, 1) + if len(loc) == 0 { + return nil + } + return loc[0] +} + +func (re *awkRegex) FindAllStringSubmatchIndex(s string, n int) [][]int { + if !re.byteMode { + return re.re.FindAllStringSubmatchIndex(s, n) + } + encoded, offsets := encodeAwkRegexBytes(s) + matches := re.re.FindAllStringSubmatchIndex(encoded, n) + for _, locs := range matches { + for i := 0; i+1 < len(locs); i += 2 { + if locs[i] < 0 { + continue + } + locs[i] = offsets[locs[i]] + locs[i+1] = offsets[locs[i+1]] + } + } + return matches +} + +func runeRangeForByteRange(s string, startByte, endByte int) (int, int) { + if startByte < 0 { + startByte = 0 + } + if startByte > len(s) { + startByte = len(s) + } + if endByte < startByte { + endByte = startByte + } + if endByte > len(s) { + endByte = len(s) + } + if startByte == endByte { + idx := runeIndexForByteOffset(s, startByte) + return idx, idx + } + return runeIndexForByteOffset(s, startByte), runeIndexAfterByteOffset(s, endByte) } -func normalizeAwkRegex(pattern string) string { +func runeIndexForByteOffset(s string, offset int) int { + if offset <= 0 { + return 0 + } + runeIndex := 0 + for i := 0; i < len(s); runeIndex++ { + _, size := utf8.DecodeRuneInString(s[i:]) + next := i + size + if offset < next { + return runeIndex + } + if offset == next { + return runeIndex + 1 + } + i = next + } + return runeIndex +} + +func runeIndexAfterByteOffset(s string, offset int) int { + if offset <= 0 { + return 0 + } + runeIndex := 0 + for i := 0; i < len(s); runeIndex++ { + _, size := utf8.DecodeRuneInString(s[i:]) + next := i + size + if offset <= next { + return runeIndex + 1 + } + i = next + } + return runeIndex +} + +func normalizeAwkRegex(pattern string) (string, bool) { var b strings.Builder + byteMode := awkRegexNeedsByteMode(pattern) for i := 0; i < len(pattern); i++ { ch := pattern[i] if ch != '\\' { + if ch >= 0x80 { + r, size := utf8.DecodeRuneInString(pattern[i:]) + if byteMode || (r == utf8.RuneError && size == 1) { + for j := i; j < i+size; j++ { + writeAwkRegexByteEscape(&b, pattern[j]) + } + i += size - 1 + continue + } + b.WriteString(pattern[i : i+size]) + i += size - 1 + continue + } b.WriteByte(ch) continue } @@ -911,10 +2081,74 @@ func normalizeAwkRegex(pattern string) string { b.WriteByte(ch) continue } + if isOctalDigit(rune(pattern[i+1])) { + value := 0 + for digits := 0; digits < 3 && i+1 < len(pattern) && isOctalDigit(rune(pattern[i+1])); digits++ { + i++ + value = value*8 + int(pattern[i]-'0') + } + writeAwkRegexByteEscape(&b, byte(value)) + continue + } i++ writeAwkRegexEscape(&b, pattern[i]) } - return b.String() + return b.String(), byteMode +} + +func awkRegexNeedsByteMode(pattern string) bool { + for i := 0; i < len(pattern); i++ { + ch := pattern[i] + if ch == '\\' && i+1 < len(pattern) && isOctalDigit(rune(pattern[i+1])) { + value := 0 + for digits := 0; digits < 3 && i+1 < len(pattern) && isOctalDigit(rune(pattern[i+1])); digits++ { + i++ + value = value*8 + int(pattern[i]-'0') + } + if byte(value) >= 0x80 { + return true + } + continue + } + if ch >= 0x80 { + r, size := utf8.DecodeRuneInString(pattern[i:]) + if r == utf8.RuneError && size == 1 { + return true + } + i += size - 1 + } + } + return false +} + +func writeAwkRegexByteEscape(b *strings.Builder, value byte) { + if value >= 0x80 { + const hex = "0123456789abcdef" + b.WriteString(`\x{`) + b.WriteByte(hex[value>>4]) + b.WriteByte(hex[value&0x0f]) + b.WriteByte('}') + return + } + b.WriteByte(value) +} + +func encodeAwkRegexBytes(s string) (string, []int) { + var b strings.Builder + offsets := []int{0} + for i := 0; i < len(s); i++ { + before := b.Len() + if s[i] >= 0x80 { + b.WriteRune(rune(s[i])) + } else { + b.WriteByte(s[i]) + } + for j := before + 1; j < b.Len(); j++ { + offsets = append(offsets, i) + } + offsets = append(offsets, i+1) + } + return b.String(), offsets } func writeAwkRegexEscape(b *strings.Builder, esc byte) { diff --git a/builtins/awk/runtime_test.go b/builtins/awk/runtime_test.go new file mode 100644 index 00000000..7d4db321 --- /dev/null +++ b/builtins/awk/runtime_test.go @@ -0,0 +1,54 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2026-present Datadog, Inc. + +package awk + +import ( + "bytes" + "context" + "io" + "os" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/DataDog/rshell/builtins" +) + +type closeTrackedFile struct { + *strings.Reader + closed bool +} + +func (f *closeTrackedFile) Write([]byte) (int, error) { + return 0, os.ErrInvalid +} + +func (f *closeTrackedFile) Close() error { + f.closed = true + return nil +} + +func TestRuntimeClosesInputsOnError(t *testing.T) { + prog, err := parseProgram(`BEGIN { getline x < "input"; print 1 / 0 }`) + require.NoError(t, err) + + opened := &closeTrackedFile{Reader: strings.NewReader("row\n")} + var stderr bytes.Buffer + callCtx := &builtins.CallContext{ + Stderr: &stderr, + OpenFile: func(context.Context, string, int, os.FileMode) (io.ReadWriteCloser, error) { + return opened, nil + }, + } + + result := newRuntime(callCtx, prog).run(context.Background(), nil) + + assert.Equal(t, uint8(1), result.Code) + assert.Contains(t, stderr.String(), "division by zero attempted") + assert.True(t, opened.closed) +} diff --git a/builtins/builtins.go b/builtins/builtins.go index e0b1de90..0b7d06e5 100644 --- a/builtins/builtins.go +++ b/builtins/builtins.go @@ -223,6 +223,12 @@ type CallContext struct { // If nil, callers should fall back to RunCommand. RunCommandWithStdin func(ctx context.Context, dir string, name string, args []string, stdin io.Reader) (uint8, error) + // RunScriptWithStdin executes an rshell script fragment within the shell's + // sandbox, with caller-provided stdin and stdout. Builtins use this for + // language features that accept command strings, so those strings are + // interpreted by rshell rather than by the host shell. + RunScriptWithStdin func(ctx context.Context, dir string, script string, stdin io.Reader, stdout io.Writer) (uint8, error) + // SetVar assigns a value to a shell variable in the calling shell's // scope. Returns an error if the value exceeds the per-variable size // limit or if the total variable-storage cap would be exceeded. diff --git a/builtins/tests/awk/awk_test.go b/builtins/tests/awk/awk_test.go index 86bb0258..e1440595 100644 --- a/builtins/tests/awk/awk_test.go +++ b/builtins/tests/awk/awk_test.go @@ -28,6 +28,32 @@ func runScript(t *testing.T, script, dir string, opts ...interp.RunnerOption) (s return runScriptCtx(context.Background(), t, script, dir, opts...) } +func runScriptRestricted(t *testing.T, script, dir string, opts ...interp.RunnerOption) (string, string, int) { + t.Helper() + parser := syntax.NewParser() + prog, err := parser.Parse(strings.NewReader(script), "") + require.NoError(t, err) + var outBuf, errBuf bytes.Buffer + allOpts := append([]interp.RunnerOption{interp.StdIO(nil, &outBuf, &errBuf)}, opts...) + runner, err := interp.New(allOpts...) + require.NoError(t, err) + defer runner.Close() + if dir != "" { + runner.Dir = dir + } + err = runner.Run(context.Background(), prog) + exitCode := 0 + if err != nil { + var es interp.ExitStatus + if errors.As(err, &es) { + exitCode = int(es) + } else { + t.Fatalf("unexpected error: %v", err) + } + } + return outBuf.String(), errBuf.String(), exitCode +} + func runScriptCtx(ctx context.Context, t *testing.T, script, dir string, opts ...interp.RunnerOption) (string, string, int) { t.Helper() parser := syntax.NewParser() @@ -64,6 +90,25 @@ func writeFile(t *testing.T, dir, name, content string) { require.NoError(t, os.WriteFile(filepath.Join(dir, name), []byte(content), 0644)) } +func TestAwkHelpDescribesSupportedAndUnsupportedProfile(t *testing.T) { + dir := t.TempDir() + stdout, stderr, code := cmdRun(t, `awk --help`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Contains(t, stdout, "Usage: awk [OPTION]... 'program' [FILE]...") + assert.Contains(t, stdout, "This is a practical rshell awk profile, not a full GNU awk clone.") + assert.Contains(t, stdout, "Supported profile:") + assert.Contains(t, stdout, "Output command pipes such as print x | \"sort\"") + assert.Contains(t, stdout, "getline, getline var, getline var < file, and \"cmd\" | getline var") + assert.Contains(t, stdout, "Not supported:") + assert.Contains(t, stdout, "system(). Use supported awk command pipes/getline pipes instead") + assert.Contains(t, stdout, "print/printf file output redirection to file targets") + assert.Contains(t, stdout, "ARGV/ARGC mutation") + assert.Contains(t, stdout, "PROCINFO, SYMTAB, FUNCTAB") + assert.Contains(t, stdout, "gensub, match, strtonum, asorti") + assert.Contains(t, stdout, "asort, patsplit") +} + func TestAwkPrintFields(t *testing.T) { dir := t.TempDir() writeFile(t, dir, "input.txt", "alpha beta gamma\none two three\n") @@ -141,6 +186,61 @@ func TestAwkSplitRegexAndCharacterSeparator(t *testing.T) { assert.Equal(t, "3 a b c\n2 x y\n2 3\n1 2 2 4\n3 a b c\n4 [] [a] [b] []\n3 [] [] []\n", stdout) } +func TestAwkSubGsubMatchAndSprintf(t *testing.T) { + dir := t.TempDir() + stdout, stderr, code := cmdRun(t, `awk 'BEGIN { s = "abc123def"; print match(s, /[0-9]+/), RSTART, RLENGTH, substr(s, RSTART, RLENGTH); sub(/[0-9]+/, "<&>", s); print s; gsub(/[a-z]+/, "X", s); print s; print sprintf("%s:%03d", "id", 7) }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "4 4 3 123\nabc<123>def\nX<123>X\nid:007\n", stdout) +} + +func TestAwkMatchCapturesGensubStrtonumAndAsorti(t *testing.T) { + dir := t.TempDir() + stdout, stderr, code := cmdRun(t, `printf 'cached_tables=31\n' | awk 'match($0, /cached_tables=([0-9]+)/, m) { print m[0], m[1] }'; awk 'BEGIN { print strtonum("0x1538"), strtonum("010"); print strtonum("123abc"), strtonum("-12.5ms"), strtonum("1e3rows"); print strtonum("012.3"), strtonum("012e2"), strtonum("0128"), strtonum("010"); print gensub(/.*trace_id=([0-9]+).*/, "\\1", 1, "trace_id=42"); a["b"] = 2; a["a"] = 1; print asorti(a, k), k[1], k[2]; a[1] = "abc"; print match(a[1], /(b)/, a), RSTART, RLENGTH, a[0], a[1] }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "cached_tables=31 31\n5432 8\n123 -12.5 1000\n12.3 1200 128 8\n42\n2 a b\n2 2 1 b b\n", stdout) +} + +func TestAwkIgnoreCaseAffectsRegexOperations(t *testing.T) { + dir := t.TempDir() + stdout, stderr, code := cmdRun(t, `printf 'TypeError\nok\n' | awk 'BEGIN { IGNORECASE = 1 } /typeerror/ { c++ } END { print c + 0 }'; awk 'BEGIN { IGNORECASE = 1; s = "TypeError"; sub(/type/, "Schema", s); print s; print split("AxxB", a, /X+/), a[1], a[2] }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "1\nSchemaError\n2 A B\n", stdout) +} + +func TestAwkByteModeMatchOffsetsUseRunePositions(t *testing.T) { + dir := t.TempDir() + stdout, stderr, code := cmdRun(t, `awk 'BEGIN { s = "\303\251"; print length(s), "[" s "]"; print match(s, /\251/), RSTART, RLENGTH, "[" substr(s, RSTART, RLENGTH) "]" }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "1 [\303\251]\n1 1 1 [\303\251]\n", stdout) +} + +func TestAwkCompositeKeysAndTernary(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "input.txt", "a x 1\na y 2\na x 3\nb x 4\n") + stdout, stderr, code := cmdRun(t, `awk '{ count[$1, $2] += $3; label = ($3 > 2 ? "big" : "small"); classes[$1, label]++ } END { print count["a", "x"], count["a", "y"], count["b", "x"]; print classes["a", "small"], classes["a", "big"]; delete count["a", "x"]; print (("a", "x") in count), (("b", "x") in count), length(SUBSEP) }' input.txt`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "4 2 4\n2 1\n0 1 1\n", stdout) +} + +func TestAwkExitRunsEndAndPreservesStatus(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "input.txt", "1\n2\n3\n") + stdout, stderr, code := cmdRun(t, `awk '{ if ($1 == 2) exit 7; print $1 } END { print "end", NR }' input.txt`, dir) + assert.Equal(t, 7, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "1\nend 2\n", stdout) + + stdout, stderr, code = cmdRun(t, `awk 'BEGIN { print "begin"; exit } { print } END { print "end" }' input.txt`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "begin\nend\n", stdout) +} + func TestAwkForWhileBreakAndContinue(t *testing.T) { dir := t.TempDir() stdout, stderr, code := cmdRun(t, `awk 'BEGIN { for (i = 1; i <= 5; i++) { if (i == 2) continue; if (i == 5) break; sum += i }; j = 0; while (j < 3) { j++; if (j == 2) continue; seen = seen j }; i = 0; for (; i < 3; i++) noinit = noinit i; for (i = 0; i < 3; i++); emptyFor = i; j = 0; while (j++ < 3); emptyWhile = j; print sum, seen, noinit, emptyFor, emptyWhile }'`, dir) @@ -200,6 +300,12 @@ func TestAwkRejectsScalarArrayNameConflicts(t *testing.T) { `awk 'BEGIN { print ENVIRON }'`, `awk 'BEGIN { FS[1] = 2 }'`, `awk 'BEGIN { NF[1] = 2 }'`, + `awk 'function f(x){ x = 1; x[1] = 2 } BEGIN { f(a) }'`, + `awk 'function f(a,b){ a = 2; b[1] = 1 } BEGIN { f(x,x) }'`, + `awk 'function f(x){ x = 1 } BEGIN { f(a); a[1] = 2 }'`, + `awk 'function f(x){ print x; x[1] = 2 } BEGIN { f(a) }'`, + `awk 'function f(x){ print x } BEGIN { f(a); a[1] = 2 }'`, + `awk 'function f(x){ print x; x[1] = 2 } BEGIN { f() }'`, } { _, stderr, code := cmdRun(t, script, dir) assert.Equal(t, 1, code, script) @@ -207,6 +313,104 @@ func TestAwkRejectsScalarArrayNameConflicts(t *testing.T) { } } +func TestAwkRejectsSpecialVariableFunctionNames(t *testing.T) { + dir := t.TempDir() + for _, script := range []string{ + `awk 'function FS(){ return 1 } BEGIN { print FS() }'`, + `awk 'function OFS(){ return 1 } BEGIN { print OFS() }'`, + `awk 'function ORS(){ return 1 } BEGIN { print ORS() }'`, + `awk 'function SUBSEP(){ return 1 } BEGIN { print SUBSEP() }'`, + `awk 'function RSTART(){ return 1 } BEGIN { print RSTART() }'`, + `awk 'function RLENGTH(){ return 1 } BEGIN { print RLENGTH() }'`, + } { + _, stderr, code := cmdRun(t, script, dir) + assert.Equal(t, 1, code, script) + assert.Contains(t, stderr, "reserved awk variable name", script) + } +} + +func TestAwkRejectsSpecialVariableFunctionParameters(t *testing.T) { + dir := t.TempDir() + for _, script := range []string{ + `awk 'function f(FS){ return FS } BEGIN { print f(1) }'`, + `awk 'function f(OFS){ return OFS } BEGIN { print f(1) }'`, + `awk 'function f(ORS){ return ORS } BEGIN { print f(1) }'`, + `awk 'function f(SUBSEP){ return SUBSEP } BEGIN { print f(1) }'`, + `awk 'function f(RSTART){ return RSTART } BEGIN { print f(1) }'`, + `awk 'function f(RLENGTH){ return RLENGTH } BEGIN { print f(1) }'`, + } { + _, stderr, code := cmdRun(t, script, dir) + assert.Equal(t, 1, code, script) + assert.Contains(t, stderr, "reserved awk variable name", script) + } +} + +func TestAwkRejectsUserFunctionNamesAsVariables(t *testing.T) { + dir := t.TempDir() + for _, script := range []string{ + `awk 'function f(){ return 1 } BEGIN { f = 3; print f }'`, + `awk 'function f(){ return 1 } BEGIN { print f }'`, + `awk 'function f(){ return 1 } BEGIN { print $f }'`, + `awk 'function f(){ return 1 } BEGIN { f[1] = 2 }'`, + `awk 'function f(){ return 1 } BEGIN { delete f }'`, + `awk 'function f(){ return 1 } BEGIN { for (f in a) print f }'`, + `awk 'function f(){ return 1 } BEGIN { for (k in f) print k }'`, + `awk 'BEGIN { f = 3 } function f(){ return 1 }'`, + `awk 'function g(){ f = 1 } function f(){ return 1 } BEGIN { g() }'`, + } { + _, stderr, code := cmdRun(t, script, dir) + assert.Equal(t, 1, code, script) + assert.Contains(t, stderr, "cannot be used as a variable or array", script) + } +} + +func TestAwkFunctionParametersMayShadowOtherFunctionNames(t *testing.T) { + dir := t.TempDir() + stdout, stderr, code := cmdRun(t, `awk 'function f(g){ print g } function g(){ return 1 } BEGIN { f(2); print g() }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "2\n1\n", stdout) +} + +func TestAwkRejectsCallsThroughShadowingParameters(t *testing.T) { + dir := t.TempDir() + for _, script := range []string{ + `awk 'function f(g){ return g() } function g(){ return 1 } BEGIN { print f(2) }'`, + `awk 'function f(g){ print g(1) } function g(x){ return x } BEGIN { f(2) }'`, + } { + _, stderr, code := cmdRun(t, script, dir) + assert.Equal(t, 1, code, script) + assert.Contains(t, stderr, "cannot be called as a function", script) + } +} + +func TestAwkRejectsLoopControlOutsideLexicalLoops(t *testing.T) { + dir := t.TempDir() + for _, tc := range []struct { + script string + err string + }{ + {`awk 'BEGIN { break }'`, "break is not allowed outside a loop"}, + {`awk 'BEGIN { continue }'`, "continue is not allowed outside a loop"}, + {`awk 'function f(){ break } BEGIN { for (i = 0; i < 2; i++) f() }'`, "break is not allowed outside a loop"}, + {`awk 'function f(){ continue } BEGIN { for (i = 0; i < 2; i++) f() }'`, "continue is not allowed outside a loop"}, + {`awk 'function f(){ if (1) { break } } BEGIN { print "unused" }'`, "break is not allowed outside a loop"}, + {`awk 'function f(){ if (1) { continue } } BEGIN { print "unused" }'`, "continue is not allowed outside a loop"}, + } { + _, stderr, code := cmdRun(t, tc.script, dir) + assert.Equal(t, 1, code, tc.script) + assert.Contains(t, stderr, tc.err, tc.script) + } +} + +func TestAwkAllowsLoopControlInsideFunctionLexicalLoops(t *testing.T) { + dir := t.TempDir() + stdout, stderr, code := cmdRun(t, `awk 'function f(){ out = ""; for (i = 0; i < 4; i++) { if (i == 1) continue; if (i == 3) break; out = out i }; return out } BEGIN { print f() }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "02\n", stdout) +} + func TestAwkExplicitEmptyActionDoesNothing(t *testing.T) { dir := t.TempDir() writeFile(t, dir, "input.txt", "alpha\n") @@ -287,6 +491,14 @@ func TestAwkRegexBracketClassCanContainSlash(t *testing.T) { assert.Equal(t, "/\n", stdout) } +func TestAwkRegexLiteralCanContainRepeatedEquals(t *testing.T) { + dir := t.TempDir() + stdout, stderr, code := cmdRun(t, `printf '=== WARM-UP ===\nplain\n' | awk '$0 ~ /===/ { print }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "=== WARM-UP ===\n", stdout) +} + func TestAwkRegexUnknownEscapesBecomeLiterals(t *testing.T) { dir := t.TempDir() writeFile(t, dir, "input.txt", "5\nd\n") @@ -333,6 +545,14 @@ func TestAwkRangePatterns(t *testing.T) { assert.Equal(t, "2:start\n3:middle\n4:end\n6:start end\n", stdout) } +func TestAwkCompoundStatementsSeparateBeforeNextStatement(t *testing.T) { + dir := t.TempDir() + stdout, stderr, code := cmdRun(t, `awk 'BEGIN { if (1) { x = 1 } print x; for (i = 1; i <= 1; i++) { if (1) y = 2 } print y }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "1\n2\n", stdout) +} + func TestAwkFieldAssignmentAndRecordRebuild(t *testing.T) { dir := t.TempDir() writeFile(t, dir, "input.txt", "a b c\n") @@ -461,6 +681,147 @@ func TestAwkVariablesTabFSAndMultipleFiles(t *testing.T) { assert.Equal(t, "row:one.tsv:1:1:1\nrow:two.tsv:1:2:2\n", stdout) } +func TestAwkSingleCharacterRecordSeparator(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "nul.txt"), []byte("alpha\x00beta\x00"), 0o644)) + writeFile(t, dir, "comma.txt", "x,y,z") + stdout, stderr, code := cmdRun(t, `awk -v RS='\0' '{ print NR ":" $0 }' nul.txt; awk -v RS=, '{ print NR ":" $0 }' comma.txt`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "1:alpha\n2:beta\n1:x\n2:y\n3:z\n", stdout) +} + +func TestAwkCommandPipes(t *testing.T) { + dir := t.TempDir() + stdout, stderr, code := cmdRun(t, `awk 'BEGIN { print "b" | "sort"; print "a" | "sort"; close("sort"); printf "%s\n", "pipe payload" | "cat"; close("cat") }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "a\nb\npipe payload\n", stdout) + + stdout, stderr, code = cmdRun(t, `awk 'BEGIN { print "auto-close" | "cat" }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "auto-close\n", stdout) + + stdout, stderr, code = cmdRun(t, `awk 'BEGIN { print "b" | "cat"; print "a"; close("cat") }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "a\nb\n", stdout) + + stdout, stderr, code = cmdRun(t, `awk 'BEGIN { cmd = "cat"; print "b" | cmd; print "a"; close(cmd) }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "a\nb\n", stdout) + + stdout, stderr, code = cmdRun(t, `awk 'BEGIN { print "x" | "cat"; print "z"; print "y" | "cat" }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "x\ny\nz\n", stdout) + + stdout, stderr, code = cmdRun(t, `awk 'BEGIN { print "a" | "wc -l"; printf ""; print "b" | "wc -l"; close("wc -l") }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "2\n", stdout) + + stdout, stderr, code = cmdRun(t, `awk 'BEGIN { print "b" | "sort"; print "mid"; print "a" | "sort"; close("sort") }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "mid\na\nb\n", stdout) + + stdout, stderr, code = cmdRun(t, `awk 'BEGIN { for (i = 1; i <= 2; i++) { print i | "cat"; print "x" } close("cat") }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "x\nx\n1\n2\n", stdout) + + stdout, stderr, code = cmdRun(t, `printf '1\n2\n' | awk '{ print $0 | "cat"; print "x" } END { close("cat") }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "x\nx\n1\n2\n", stdout) + + stdout, stderr, code = cmdRun(t, `awk 'function f(x) { print x | "sort"; print "s" } BEGIN { f("b"); f("a"); close("sort") }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "s\ns\na\nb\n", stdout) + + stdout, stderr, code = cmdRun(t, `awk 'BEGIN { print "x" | "false" }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "", stdout) + + stdout, stderr, code = cmdRun(t, `awk 'BEGIN { print "x" | "false"; print "after"; print close("false") }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "after\n1\n", stdout) + + stdout, stderr, code = cmdRun(t, `awk 'BEGIN { print "x" | "false"; print close("false") }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "1\n", stdout) + + stdout, stderr, code = cmdRun(t, `awk 'BEGIN { print close("missing") }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "-1\n", stdout) +} + +func TestAwkCommandPipesRunNestedRshellScripts(t *testing.T) { + dir := t.TempDir() + stdout, stderr, code := cmdRun(t, `awk 'BEGIN { cmd = "cat | sort"; print "b" | cmd; print "a" | cmd; print close(cmd) }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "a\nb\n0\n", stdout) +} + +func TestAwkCommandInputPipesUseNestedRshellScripts(t *testing.T) { + dir := t.TempDir() + stdout, stderr, code := cmdRun(t, `awk 'BEGIN { cmd = "printf \"b\\na\\n\" | sort"; print (cmd | getline first), first; print (cmd | getline second), second; print (cmd | getline third), "[" third "]"; print close(cmd); print (cmd | getline again), again }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "1 a\n1 b\n0 []\n0\n1 a\n", stdout) +} + +func TestAwkCommandInputPipesInheritUnopenedStdin(t *testing.T) { + dir := t.TempDir() + stdout, stderr, code := cmdRun(t, `printf "outer\n" | awk 'BEGIN { "cat" | getline x; print "x=" x; getline y; print "y=" y }'`, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "x=outer\ny=\n", stdout) +} + +func TestAwkCommandInputPipesKeepStdinWhileReadingFiles(t *testing.T) { + dir := t.TempDir() + input := filepath.Join(dir, "input.txt") + require.NoError(t, os.WriteFile(input, []byte("file-record\n"), 0o644)) + quotedInput := "'" + strings.ReplaceAll(input, "'", `'\''`) + "'" + + stdout, stderr, code := cmdRun(t, `printf "s\n" | awk '{ "cat" | getline x; print "x=" x; print "rec=" $0 }' `+quotedInput, dir) + assert.Equal(t, 0, code) + assert.Equal(t, "", stderr) + assert.Equal(t, "x=s\nrec=file-record\n", stdout) +} + +func TestAwkCommandPipesRespectAllowedCommands(t *testing.T) { + dir := t.TempDir() + stdout, stderr, code := runScriptRestricted(t, `awk 'BEGIN { print "x" | "sort" }'`, dir, + interp.AllowedCommands([]string{"rshell:awk"}), + interp.AllowedPaths([]string{dir}), + ) + assert.Equal(t, 1, code) + assert.Equal(t, "", stdout) + assert.Contains(t, stderr, `rshell: sort: command not allowed`) +} + +func TestAwkNestedCommandPipesRespectAllowedCommands(t *testing.T) { + dir := t.TempDir() + stdout, stderr, code := runScriptRestricted(t, `awk 'BEGIN { print "x" | "cat | sort" }'`, dir, + interp.AllowedCommands([]string{"rshell:awk", "rshell:cat"}), + interp.AllowedPaths([]string{dir}), + ) + assert.Equal(t, 1, code) + assert.Equal(t, "", stdout) + assert.Contains(t, stderr, `rshell: sort: command not allowed`) +} + func TestAwkOperandAssignments(t *testing.T) { dir := t.TempDir() writeFile(t, dir, "one.txt", "a\n") @@ -482,6 +843,14 @@ func TestAwkOperandAssignments(t *testing.T) { assert.Equal(t, "c\n", stdout) } +func TestAwkMissingInputFileIsFatal(t *testing.T) { + dir := t.TempDir() + stdout, stderr, code := cmdRun(t, `awk '{ print }' missing.txt`, dir) + assert.Equal(t, 2, code) + assert.Equal(t, "", stdout) + assert.Contains(t, stderr, "awk: fatal: cannot open file `missing.txt' for reading:") +} + func TestAwkAppliesFieldSeparatorOptionsInOrder(t *testing.T) { dir := t.TempDir() writeFile(t, dir, "input.txt", "a:b,c\n") @@ -508,9 +877,7 @@ func TestAwkRejectsUnsafeFeatures(t *testing.T) { `awk '{ system("sh") }' input.txt`, `awk '{ print $1 > "out" }' input.txt`, `awk '{ printf "%s", $1 > "out" }' input.txt`, - `awk '{ print getline }' input.txt`, `awk '{ x = next }' input.txt`, - `awk '{ exit 0 }' input.txt`, `awk 'BEGIN { next }' input.txt`, `awk 'BEGIN { print tolower(), toupper(), int() }' input.txt`, `awk '{ print int() }' empty.txt`, diff --git a/docs/AWK_IMPLEMENTATION_PLAN.md b/docs/AWK_IMPLEMENTATION_PLAN.md index eefce62f..8673e41d 100644 --- a/docs/AWK_IMPLEMENTATION_PLAN.md +++ b/docs/AWK_IMPLEMENTATION_PLAN.md @@ -245,15 +245,19 @@ The builtin must preserve rshell's no-write, no-host-exec safety model. Reject or defer: - `system()` -- command pipes: `print | "cmd"` and `"cmd" | getline` - coprocesses - output redirection to files: `print > "file"` and `print >> "file"` -- `getline` in all forms for Phase 1 - dynamic extension loading - network special files - any feature that executes host commands - any feature that writes, creates, modifies, or deletes files +Output and input command pipes such as `print ... | "sort"` and +`"printf \"b\\na\\n\" | sort" | getline line` are permitted in Phase 4 only +through rshell's controlled builtin execution model. They do not invoke a host +shell; command strings are parsed and executed by rshell, so the normal command +allowlist, path policy, and parser restrictions still apply. + All file reads must go through `callCtx.OpenFile`. ## Implementation Files @@ -390,12 +394,35 @@ Implementation order used by `codex/awk-phase-3`: Phase 4 candidates: -- user-defined functions -- additional POSIX awk builtins -- carefully restricted `getline`, only if a safe design is approved -- safe command pipes through rshell's controlled execution model, only if a - concrete non-host-escape design is approved -- safe GNU awk compatibility extensions that do not violate rshell policy +Phase 4 should make the builtin investigation-grade for LLM-generated awk +programs without attempting a full GNU awk clone. Prioritize features that +unlock common log, table, and small-report workflows: + +- regex text editing and extraction: `sub`, `gsub`, `gensub`, `match`, + capture arrays, `RSTART`, and `RLENGTH` +- expression formatting: `sprintf` +- composite array keys with `SUBSEP`, such as `count[$1, $2]++` +- compact expression/control ergonomics: ternary `cond ? a : b`, `exit [code]`, + and, if it remains small, `do ... while` +- user-defined functions with `return`; array parameters are preferred over a + scalar-only subset because practical helper functions often receive arrays +- safe command pipes such as `print ... | "sort"`, `"cmd" | getline line`, + and `close(cmd)`, implemented only through rshell's controlled builtin + execution model +- practical `getline` forms that read from the current input stream or from + files through `callCtx.OpenFile` +- focused utility builtins that support investigations, starting with + `strtonum` for `/proc/net/*` hex decoding and `asorti` for deterministic + reports +- practical record splitting controls such as single-character `RS`, + including NUL for `/proc/*/cmdline` and `/proc/*/environ`, plus + `IGNORECASE` for case-insensitive log scans + +Defer or reject low-value or high-risk GNU awk compatibility surfaces: +`system()`, unrestricted file redirection, `PROCINFO`, `SYMTAB`, `FUNCTAB`, +namespaces, `include`, `load`, `FIELDWIDTHS`, +`FPAT`, CSV mode, math/time/random builtins, i18n builtins, bitwise builtins, +and broad introspection. ## Open Design Questions diff --git a/interp/runner_exec.go b/interp/runner_exec.go index c35a9e18..bb8a343d 100644 --- a/interp/runner_exec.go +++ b/interp/runner_exec.go @@ -556,6 +556,50 @@ func (r *Runner) call(ctx context.Context, pos syntax.Pos, args []string) { }) } var runCmdWithStdin func(context.Context, string, string, []string, io.Reader) (uint8, error) + var runScriptWithStdin func(context.Context, string, string, io.Reader, io.Writer) (uint8, error) + runScriptWithStdin = func(ctx context.Context, dir string, script string, childStdin io.Reader, childStdout io.Writer) (uint8, error) { + prog, err := ParseScript(script, "awk-command") + if err != nil { + return 2, err + } + childStdinFile, err := stdinFile(ctx, childStdin) + if err != nil { + return 1, err + } + if original, ok := childStdin.(*os.File); !ok || original != childStdinFile { + if childStdinFile != nil { + defer childStdinFile.Close() + } + } + if childStdout == nil { + childStdout = io.Discard + } + child := r.subshell(false) + if dir != "" { + child.Dir = dir + } + child.stdin = childStdinFile + child.stdout = childStdout + child.stderr = r.stderr + child.runStdin = childStdinFile + child.runStdout = childStdout + child.inPipeline = false + child.exit = exitStatus{} + child.stmts(ctx, prog.Stmts) + child.exit.exiting = false + + r.totalCount += child.totalCount + r.dispatchedCount += child.dispatchedCount + r.unallowedCount += child.unallowedCount + r.unknownCount += child.unknownCount + if child.exit.fatalExit { + return child.exit.code, child.exit.err + } + if child.unallowedCount > 0 { + return child.exit.code, fmt.Errorf("nested command not allowed") + } + return child.exit.code, nil + } runCmdWithStdin = func(ctx context.Context, dir string, cmdName string, cmdArgs []string, childStdin io.Reader) (uint8, error) { if !r.allowAllCommands && !r.allowedCommands[cmdName] { return 127, fmt.Errorf("rshell: %s: command not allowed", cmdName) @@ -643,6 +687,7 @@ func (r *Runner) call(ctx context.Context, pos syntax.Pos, args []string) { return runCmdWithStdin(ctx, dir, name, args, childStdin) }, RunCommandWithStdin: runCmdWithStdin, + RunScriptWithStdin: runScriptWithStdin, // Intentionally not exposing SetVar / GetVar in the // child CallContext used for find -exec / -execdir // grandchildren. find treats each invocation as a @@ -742,6 +787,7 @@ func (r *Runner) call(ctx context.Context, pos syntax.Pos, args []string) { }, RunCommand: runCmd, RunCommandWithStdin: runCmdWithStdin, + RunScriptWithStdin: runScriptWithStdin, SetVar: func(name, value string) error { if len(value) > MaxVarBytes { return fmt.Errorf("%s: value too large (limit %d bytes)", name, MaxVarBytes) diff --git a/tests/awk_scenarios/enabled.txt b/tests/awk_scenarios/enabled.txt index 338cf37a..e979ee88 100644 --- a/tests/awk_scenarios/enabled.txt +++ b/tests/awk_scenarios/enabled.txt @@ -1,66 +1,224 @@ +gawk/arrays/aliased_array_params_share_updates.yaml +gawk/arrays/array_creation_through_nested_call.yaml +gawk/arrays/array_parameter_delete_iteration.yaml +gawk/arrays/array_reference_side_effect.yaml +gawk/arrays/asorti_ignorecase_index_order.yaml gawk/arrays/associative_count.yaml gawk/arrays/delete_index.yaml +gawk/arrays/delete_local_array_parameter.yaml +gawk/arrays/delete_parameter_reuse.yaml +gawk/arrays/empty_key_global_alias.yaml +gawk/arrays/getline_delete_array_reuse.yaml +gawk/arrays/getline_empty_array_element_redirection.yaml +gawk/arrays/global_parameter_array_updates.yaml +gawk/arrays/in_operator.yaml +gawk/arrays/local_array_reuse_after_scalar_parameter.yaml +gawk/arrays/missing_argument_passed_as_scalar.yaml +gawk/arrays/multidim_table_slots.yaml +gawk/arrays/numeric_string_subscript_preserves_lexeme.yaml +gawk/arrays/numeric_subscript_convfmt_stability.yaml +gawk/arrays/numeric_subscript_debug_classification.yaml +gawk/arrays/numeric_test_on_unassigned_element.yaml +gawk/arrays/repeated_split_after_array_delete.yaml +gawk/arrays/split_into_array_parameter.yaml +gawk/arrays/split_local_array_after_scalar_buffer.yaml +gawk/arrays/string_numeric_subscript.yaml gawk/arrays/subscript_name_keeps_scalar_value.yaml +gawk/arrays/template_substitution_marker_arrays.yaml gawk/arrays/unassigned_subscript_empty_string.yaml gawk/basic/begin_end_records.yaml gawk/basic/field_separator.yaml +gawk/cli/binmode_variable_assignment.yaml +gawk/cli/terminal_backslash_argument.yaml +gawk/control/exit_runs_end.yaml gawk/control/for_loop_fields.yaml gawk/control/if_else.yaml gawk/control/while_break.yaml gawk/expressions/appended_numeric_string_reconverts.yaml gawk/expressions/arithmetic_comparison.yaml gawk/expressions/concat_literal_punctuation.yaml +gawk/expressions/concat_parenthesized_uninitialized.yaml +gawk/expressions/conditional_operator.yaml +gawk/expressions/concat_after_getline_index.yaml +gawk/expressions/function_local_concat.yaml +gawk/expressions/function_parameter_concatenation_copy.yaml gawk/expressions/leading_digit_exponent_fragment.yaml gawk/expressions/negative_fraction_integer_format.yaml gawk/expressions/nondecimal_string_parameter.yaml gawk/expressions/numeric_string_division.yaml gawk/expressions/numeric_substr_padding.yaml +gawk/expressions/saved_record_string_compare.yaml gawk/expressions/string_concatenation.yaml gawk/expressions/string_constant_numeric_comparison.yaml gawk/expressions/string_field_number_reference.yaml +gawk/expressions/string_numeric_compare.yaml gawk/expressions/unary_minus_string_operand.yaml gawk/expressions/unary_plus_preserves_decimal_string_value.yaml gawk/fields/assign_rebuilds_record.yaml gawk/fields/empty_field_assignment_preserves_nf.yaml +gawk/fields/gsub_assignment_resplits_record.yaml gawk/fields/nf_assignment.yaml gawk/fields/numeric_field_terminator.yaml +gawk/fields/substitution_then_field_assignment.yaml +gawk/functions/array_parameter_reuse.yaml +gawk/functions/comma_formatting.yaml +gawk/functions/delete_array_inside_for_loop.yaml +gawk/functions/delete_array_parameter_elements.yaml +gawk/functions/delete_whole_array_parameter.yaml +gawk/functions/function_semicolon_newline.yaml +gawk/functions/getline_current_input.yaml +gawk/functions/length_array_parameter.yaml +gawk/functions/match_position.yaml +gawk/functions/nested_function_stack_arrays.yaml gawk/functions/printf_width_precision_mix.yaml +gawk/functions/scalar_parameter_does_not_alias_global.yaml gawk/functions/split.yaml gawk/functions/split_default_separator.yaml gawk/functions/string_core.yaml +gawk/functions/tail_recursive_array_argument.yaml +gawk/input/exit_end_bare_preserves_status.yaml +gawk/input/exit_end_status_override.yaml +gawk/input/exit_expression_stops_begin.yaml +gawk/input/function_call_arg_exit_begin.yaml +gawk/input/function_call_arg_exit_record.yaml +gawk/input/getline_after_marker_long_record.yaml +gawk/input/getline_after_marker_variable.yaml +gawk/input/getline_array_index_eof.yaml +gawk/input/getline_begin_reads_argv_files.yaml +gawk/input/getline_directory_error.yaml +gawk/input/getline_eof_after_fs_change.yaml +gawk/input/getline_field_increment_syntax.yaml +gawk/input/getline_target_expression_stdin.yaml gawk/input/no_trailing_newline_regex.yaml gawk/input/nr_concat_builtin_records.yaml +gawk/input/nr_concat_end_block.yaml +gawk/io/close_current_filename_not_redirection.yaml +gawk/io/close_missing_input_redirection.yaml +gawk/io/end_block_close_reopens_file.yaml +gawk/io/getline_extra_expression.yaml +gawk/io/input_redirection_precedence.yaml +gawk/io/paragraph_backslash_fs.yaml +gawk/io/paragraph_split_uses_fs.yaml +gawk/io/reparse_saved_record_fields.yaml gawk/misc/assign_extends_record.yaml gawk/misc/begin_print_hello.yaml +gawk/misc/byte_range_regex_c_locale.yaml gawk/misc/compound_assignment_subscript_side_effect.yaml +gawk/misc/concat_uses_left_value_before_function_side_effect.yaml +gawk/misc/getline_preserves_parameter_copy.yaml gawk/misc/in_operator_assignment_value.yaml gawk/misc/last_field_concat_once.yaml gawk/misc/nested_self_compound_assignment.yaml +gawk/misc/nul_string_comparison.yaml +gawk/misc/print_argument_function_output_order.yaml +gawk/misc/print_evaluates_function_result_once.yaml +gawk/misc/printf_argument_value_before_function_side_effect.yaml gawk/misc/printf_plus_flag_decimal.yaml gawk/misc/range_pattern_boundaries.yaml +gawk/misc/sub_complex_regex_no_loop_double_quote.yaml +gawk/misc/sub_complex_regex_no_loop_embedded_quote.yaml gawk/output/hex_input_numeric_conversion.yaml +gawk/output/hex_literal_token_boundaries.yaml gawk/output/integer_precision_padding.yaml +gawk/output/multibyte_char_width_precision.yaml +gawk/output/multibyte_field_alignment.yaml +gawk/output/multibyte_left_width.yaml +gawk/output/multibyte_percent_c_numeric_string.yaml +gawk/output/multibyte_printf_roundtrip.yaml +gawk/output/ofmt_big_numeric_extrema.yaml +gawk/output/ofmt_directory_extrema.yaml +gawk/output/ofmt_string_format_preserves_fields.yaml +gawk/output/ofmt_strnum_keeps_original_text.yaml gawk/output/print_separators.yaml +gawk/output/printf_c_array_index_is_string.yaml +gawk/output/printf_floating_flag_grid.yaml gawk/output/printf_format.yaml gawk/output/printf_zero_precision_hex_resets_alternate.yaml +gawk/output/sprintf_c_conversion_records.yaml +gawk/output/sprintf_value.yaml gawk/output/zero_flag_ignored_with_integer_precision.yaml +gawk/records/begin_field_arg_before_record_reassign.yaml +gawk/records/command_line_fs_space_colon_plus.yaml +gawk/records/empty_string_array_index.yaml +gawk/records/fieldwidths_disabled_by_fs_assignment.yaml +gawk/records/fs_alternation_start_anchor_empty_field.yaml +gawk/records/fs_caret_dot_rebuild.yaml gawk/records/fs_single_backslash.yaml +gawk/records/fs_tab_plus_repeated_tabs.yaml +gawk/records/function_arg_before_record_reassign.yaml +gawk/records/nf_assignment_truncates_and_extends.yaml +gawk/records/nf_extension_loop_rebuild.yaml +gawk/records/nf_increment_preserves_function_parameter.yaml +gawk/records/nul_fs_string_split.yaml +gawk/records/resplit_record_after_fs_change.yaml +gawk/regex/array_subscript_divide_assignment.yaml +gawk/regex/backslash_big_s_nonspace.yaml +gawk/regex/backslash_small_s_repetition.yaml +gawk/regex/backslash_small_s_single_whitespace.yaml +gawk/regex/backslash_w_word_match.yaml +gawk/regex/dfa_anchored_repetition_backtracking.yaml gawk/regex/dfa_nested_closure_alternation.yaml gawk/regex/escaped_left_brace_literal.yaml +gawk/regex/gensub_record_self_assignment.yaml +gawk/regex/gsub_end_anchor_alternation.yaml +gawk/regex/gsub_field_no_match_preserves_record.yaml +gawk/regex/gsub_ofs_target_affects_print_separator.yaml +gawk/regex/gsub_punctuation_bracket_class.yaml +gawk/regex/gsub_replacement.yaml +gawk/regex/match_empty_string_utf8_locale.yaml +gawk/regex/match_last_field_dynamic_regex.yaml +gawk/regex/match_nullable_uninitialized.yaml +gawk/regex/match_uninitialized_empty_values.yaml gawk/regex/pattern_match.yaml +gawk/regex/sub_ampersand.yaml +gawk/regex/sub_escaped_ampersand.yaml +gawk/regex/sub_multibyte_repeated_substr.yaml +gawk/string_regex/bracket_range_edge_cases.yaml +gawk/string_regex/eight_bit_bracket_backtracking.yaml +gawk/string_regex/escaped_punctuation_bracket_substitution.yaml +gawk/string_regex/ignorecase_numeric_string_truth.yaml +gawk/string_regex/ignorecase_posix_alnum_class.yaml +gawk/string_regex/independent_regex_operator_precedence.yaml +gawk/string_regex/letter_range_membership.yaml +gawk/string_regex/long_prefix_substitution.yaml +gawk/string_regex/long_words_regex_collection.yaml +gawk/string_regex/multibyte_match_substr_offsets.yaml +gawk/string_regex/negative_dash_range_separator.yaml +gawk/string_regex/nul_dynamic_regexp_operators.yaml +gawk/string_regex/numeric_string_array_keys.yaml +gawk/string_regex/octal_numeric_subscript.yaml +gawk/string_regex/reparse_after_record_rebuild.yaml +gawk/string_regex/space_and_blank_classes.yaml +gawk/string_regex/split_destination_aliases_source.yaml +gawk/string_regex/split_dynamic_separator_variable.yaml +gawk/string_regex/split_space_string_vs_regexp.yaml +gawk/string_regex/strnum_string_format_preserved.yaml +gawk/string_regex/strtod_hex_prefix_and_zero_strings.yaml +gawk/text/index_updates_after_substitution.yaml +gawk/text/getline_swaps_adjacent_lines.yaml +gawk/text/numeric_subsep_composite_key.yaml gawk/text/print_records_verbatim.yaml +gawk/text/repeated_sub_extracts_quoted_values.yaml +gawk/text/substitution_refreshes_index_offsets.yaml +gawk/text/utf8_index_after_getline_concat.yaml +gawk/text/valgrind_log_scanner_reports_loss.yaml +onetrueawk/arrays/delete_composite_subscripts.yaml onetrueawk/arrays/delete_current_key.yaml onetrueawk/arrays/first_seen_totals.yaml onetrueawk/arrays/record_storage_split.yaml onetrueawk/arrays/regex_bucket_counts.yaml +onetrueawk/arrays/split_membership_in.yaml onetrueawk/arrays/unique_field_counts.yaml onetrueawk/basic/begin_filename_and_end_nr.yaml onetrueawk/basic/comments_ignored.yaml onetrueawk/basic/pattern_action.yaml onetrueawk/basic/record_counter_nr.yaml +onetrueawk/control/begin_getline_exit.yaml +onetrueawk/control/begin_getline_then_main.yaml onetrueawk/control/division_loop_variants.yaml onetrueawk/control/for_each_field_reverse.yaml onetrueawk/control/infinite_for_next_record.yaml +onetrueawk/core/assert_function_return_comparison.yaml onetrueawk/core/assign_existing_field_constant.yaml onetrueawk/core/assign_first_field_from_nr.yaml onetrueawk/core/assign_last_field_from_nr.yaml @@ -68,6 +226,7 @@ onetrueawk/core/assign_record_from_second_field.yaml onetrueawk/core/break_end_stored_records.yaml onetrueawk/core/break_inner_loop_only.yaml onetrueawk/core/break_preserves_matching_element.yaml +onetrueawk/core/concat_with_preincrement.yaml onetrueawk/core/continue_skips_numeric_fields.yaml onetrueawk/core/custom_ors_without_final_newline.yaml onetrueawk/core/delete_numeric_and_string_keys.yaml @@ -75,6 +234,7 @@ onetrueawk/core/delete_split_element_count.yaml onetrueawk/core/dynamic_field_zero_or_one_assignment.yaml onetrueawk/core/dynamic_first_field_division.yaml onetrueawk/core/end_record_count.yaml +onetrueawk/core/exit_from_function_runs_end.yaml onetrueawk/core/field_assignment_rebuild_marker.yaml onetrueawk/core/field_reference_order.yaml onetrueawk/core/first_seen_amount_totals.yaml @@ -83,13 +243,26 @@ onetrueawk/core/for_in_counts_and_total.yaml onetrueawk/core/for_increment_expression_sums_fields.yaml onetrueawk/core/for_loop_multiline_clauses.yaml onetrueawk/core/for_loop_next_after_fields.yaml +onetrueawk/core/function_arity_unused_args.yaml +onetrueawk/core/function_order_field_access.yaml +onetrueawk/core/function_parameter_locality.yaml +onetrueawk/core/function_side_effect_before_return_concat.yaml +onetrueawk/core/function_split_array_argument.yaml +onetrueawk/core/gsub_default_record_vowels.yaml +onetrueawk/core/gsub_dynamic_char_class_ampersand.yaml +onetrueawk/core/gsub_dynamic_first_character.yaml +onetrueawk/core/gsub_end_anchor_appends.yaml onetrueawk/core/if_truthy_fields.yaml onetrueawk/core/inline_comments_inside_action.yaml +onetrueawk/core/match_function_sets_offsets.yaml +onetrueawk/core/missing_later_field_empty.yaml onetrueawk/core/next_skips_later_action.yaml onetrueawk/core/not_operator_patterns.yaml onetrueawk/core/numeric_field_comparison_pattern.yaml onetrueawk/core/numeric_literal_regex_pattern.yaml onetrueawk/core/or_pattern_with_regex.yaml +onetrueawk/core/overlapping_range_patterns.yaml +onetrueawk/core/postincrement_dynamic_field_sum.yaml onetrueawk/core/prefix_postfix_increment_counters.yaml onetrueawk/core/range_pattern_basic.yaml onetrueawk/core/regex_bracket_classes_dynamic.yaml @@ -99,9 +272,14 @@ onetrueawk/core/running_sum_and_final_total.yaml onetrueawk/core/same_regex_range_records.yaml onetrueawk/core/split_fields_reordered.yaml onetrueawk/core/split_reuses_source_array.yaml +onetrueawk/core/sub_and_gsub_replacement_forms.yaml +onetrueawk/core/sub_last_character.yaml +onetrueawk/core/substr_key_accumulation.yaml +onetrueawk/core/substr_nonpositive_range.yaml onetrueawk/core/tt01_print_records.yaml onetrueawk/core/tt02_nr_nf_record.yaml onetrueawk/core/tt03_sum_second_field_lengths.yaml +onetrueawk/core/tt04_reverse_fields_printf.yaml onetrueawk/core/tt05_reverse_fields_string.yaml onetrueawk/core/tt06_group_lengths_for_in.yaml onetrueawk/core/tt07_even_field_count_pattern.yaml @@ -111,7 +289,11 @@ onetrueawk/core/tt10_nonempty_end_pattern.yaml onetrueawk/core/tt11_fixed_substr.yaml onetrueawk/core/tt12_field_string_and_decrement.yaml onetrueawk/core/tt13_store_fields_in_array.yaml +onetrueawk/core/tt15_small_formatter_functions.yaml +onetrueawk/core/tt16_word_counts_without_sort.yaml +onetrueawk/core/uninitialized_and_empty_field_comparisons.yaml onetrueawk/core/uninitialized_concat_prefix.yaml +onetrueawk/expressions/builtin_numeric_coercions.yaml onetrueawk/expressions/number_string_conversion.yaml onetrueawk/expressions/numeric_string_exclusions.yaml onetrueawk/expressions/string_range_comparisons.yaml @@ -126,30 +308,68 @@ onetrueawk/fields/nf_assignment_rebuild.yaml onetrueawk/fields/regex_field_separator_tabs.yaml onetrueawk/fields/set_record_from_field.yaml onetrueawk/fixtures/t_1_x_concatenated_assignment.yaml +onetrueawk/fixtures/t_2_x_field_assignment_preserves_saved_value.yaml +onetrueawk/fixtures/t_3_x_division_loop.yaml onetrueawk/fixtures/t_4_x_parenthesized_field_reference.yaml +onetrueawk/fixtures/t_5_x_dynamic_first_field_assignment.yaml onetrueawk/fixtures/t_6_x_nf_and_record_printing.yaml +onetrueawk/fixtures/t_8_x_second_field_creation_on_empty_record.yaml +onetrueawk/fixtures/t_8_y_first_field_from_missing_second.yaml onetrueawk/fixtures/t_d_x_colon_separator_nf.yaml +onetrueawk/fixtures/t_format4_sprintf_width_substr.yaml +onetrueawk/fixtures/t_intest2_composite_membership.yaml onetrueawk/fixtures/t_longstr_literal_preserved.yaml +onetrueawk/fixtures/t_makef_assign_third_field.yaml onetrueawk/fixtures/t_monotone_optional_regex_chain.yaml +onetrueawk/fixtures/t_nameval_first_seen_totals.yaml +onetrueawk/fixtures/t_pipe_print_to_command.yaml onetrueawk/fixtures/t_quote_field_with_literal_quotes.yaml +onetrueawk/fixtures/t_reg_bracket_regexes.yaml +onetrueawk/fixtures/t_roff_word_wrap_state.yaml onetrueawk/fixtures/t_sep_digit_field_separator.yaml onetrueawk/fixtures/t_seqno_record_numbers.yaml +onetrueawk/fixtures/t_split8_regex_whitespace_split.yaml +onetrueawk/fixtures/t_split9_fs_split.yaml +onetrueawk/fixtures/t_split9a_literal_fs_split.yaml onetrueawk/fixtures/t_stately_grouped_alternation_repetition.yaml +onetrueawk/fixtures/t_time_suffix_records_summary.yaml +onetrueawk/fixtures/t_vf1_iterate_fields.yaml +onetrueawk/fixtures/t_vf2_postincrement_last_field.yaml +onetrueawk/fixtures/t_vf3_dynamic_field_assignment.yaml onetrueawk/fixtures/t_vf_dynamic_field_read.yaml onetrueawk/fixtures/t_x_regex_default_print.yaml +onetrueawk/fixtures/tt_02a_second_field_length_assignment.yaml onetrueawk/fixtures/tt_03a_third_field_sum.yaml onetrueawk/fixtures/tt_10a_dynamic_dot_end_regex.yaml +onetrueawk/fixtures/tt_13a_numbered_field_snapshot.yaml +onetrueawk/fixtures/tt_big_multi_action_program.yaml +onetrueawk/functions/array_parameter_split.yaml +onetrueawk/functions/field_arguments_are_values.yaml +onetrueawk/functions/function_numeric_loop.yaml onetrueawk/functions/index_substring_positions.yaml onetrueawk/functions/split_default_fields.yaml onetrueawk/functions/split_dynamic_separator.yaml onetrueawk/functions/split_regex_separator.yaml +onetrueawk/functions/sub_ampersand_replacement.yaml +onetrueawk/functions/sub_string_pattern.yaml onetrueawk/functions/substr_pattern_filters.yaml +onetrueawk/input/getline_groups_records.yaml onetrueawk/output/custom_ofs.yaml onetrueawk/output/ofs_ors_print.yaml onetrueawk/output/printf_numeric_formats.yaml +onetrueawk/output/printf_sprintf_width.yaml +onetrueawk/programs/chemical_formula_atom_counts.yaml onetrueawk/programs/constant_string_concatenation.yaml onetrueawk/programs/delete_element_and_array.yaml +onetrueawk/programs/dynamic_regex_cache_sub_replacement.yaml +onetrueawk/programs/expression_precedence_and_numeric_strings.yaml onetrueawk/programs/expression_result_numeric_conversion.yaml +onetrueawk/programs/field_separator_option_variants.yaml +onetrueawk/programs/gawk_backslash_gsub_and_reparse.yaml +onetrueawk/programs/getline_variable_preserves_record.yaml +onetrueawk/programs/interval_expression_boundaries.yaml +onetrueawk/programs/large_string_fields_and_array_delete.yaml +onetrueawk/programs/misc_record_rebuild_and_end_state.yaml onetrueawk/programs/p01_print_records.yaml onetrueawk/programs/p02_print_selected_fields.yaml onetrueawk/programs/p03_printf_columns.yaml @@ -179,6 +399,7 @@ onetrueawk/programs/p26_accumulate_asia_long_assignment.yaml onetrueawk/programs/p26a_accumulate_asia_compound_assignment.yaml onetrueawk/programs/p27_maximum_numeric_field.yaml onetrueawk/programs/p28_nr_colon_record_concat.yaml +onetrueawk/programs/p29_gsub_record_default_target.yaml onetrueawk/programs/p30_length_builtin_current_record.yaml onetrueawk/programs/p31_longest_first_field.yaml onetrueawk/programs/p32_substr_field_rebuild.yaml @@ -190,16 +411,28 @@ onetrueawk/programs/p37_concatenated_field_equality.yaml onetrueawk/programs/p38_block_if_maximum.yaml onetrueawk/programs/p39_while_print_each_field.yaml onetrueawk/programs/p40_for_print_each_field.yaml +onetrueawk/programs/p41_exit_before_end_line_count.yaml onetrueawk/programs/p42_array_accumulate_regex_buckets.yaml onetrueawk/programs/p43_area_by_group_for_in.yaml +onetrueawk/programs/p44_recursive_factorial_function.yaml onetrueawk/programs/p45_ofs_ors_print.yaml onetrueawk/programs/p46_adjacent_field_concatenation.yaml +onetrueawk/programs/p48_array_totals_piped_sort.yaml +onetrueawk/programs/p50_composite_key_piped_sort.yaml +onetrueawk/programs/p51_grouped_colon_report.yaml +onetrueawk/programs/p52_grouped_totals_report.yaml onetrueawk/programs/p5a_tabular_header_printf.yaml +onetrueawk/programs/p_table_simple_formatter.yaml +onetrueawk/programs/recursive_functions_and_array_params.yaml onetrueawk/programs/regular_expression_operator_matrix.yaml onetrueawk/programs/split_empty_separator_and_fs_reparse.yaml +onetrueawk/programs/sub_gsub_replacement_edges.yaml +onetrueawk/programs/utf8_length_index_substr_printf.yaml +onetrueawk/programs/utf8_regular_expression_matches.yaml onetrueawk/records/longest_record.yaml onetrueawk/records/modulo_pattern_default_print.yaml onetrueawk/records/sum_count_average.yaml +onetrueawk/regex/array_regex_patterns.yaml onetrueawk/regex/compound_pattern_conditions.yaml onetrueawk/regex/dynamic_regex_from_field.yaml onetrueawk/regex/dynamic_regex_literals.yaml diff --git a/tests/awk_scenarios/gawk/arrays/delete_parameter_reuse.yaml b/tests/awk_scenarios/gawk/arrays/delete_parameter_reuse.yaml index 884b9902..32ca4275 100644 --- a/tests/awk_scenarios/gawk/arrays/delete_parameter_reuse.yaml +++ b/tests/awk_scenarios/gawk/arrays/delete_parameter_reuse.yaml @@ -22,14 +22,14 @@ input: BEGIN { clear(table) fill(table) - for (key in table) - print key, table[key] + print "one", table["one"] + print "two", table["two"] clear(table) print length(table) } expect: - stdout_contains: - - one 1 - - two 2 - - "0" + stdout: | + one 1 + two 2 + 0 exit_code: 0 diff --git a/tests/awk_scenarios/gawk/arrays/getline_empty_array_element_redirection.yaml b/tests/awk_scenarios/gawk/arrays/getline_empty_array_element_redirection.yaml index cb4ac132..dd0a3564 100644 --- a/tests/awk_scenarios/gawk/arrays/getline_empty_array_element_redirection.yaml +++ b/tests/awk_scenarios/gawk/arrays/getline_empty_array_element_redirection.yaml @@ -6,6 +6,7 @@ upstream: covers: - missing array element redirection operands evaluate to the empty string - getline input redirection rejects a null filename +oracle_stderr_skip: rshell emits compact fatal diagnostics without GNU awk command-line prefixes. input: program: | BEGIN { diff --git a/tests/awk_scenarios/gawk/input/getline_field_increment_syntax.yaml b/tests/awk_scenarios/gawk/input/getline_field_increment_syntax.yaml index de98fffc..bc35e3e9 100644 --- a/tests/awk_scenarios/gawk/input/getline_field_increment_syntax.yaml +++ b/tests/awk_scenarios/gawk/input/getline_field_increment_syntax.yaml @@ -6,6 +6,7 @@ upstream: covers: - getline requires an assignable target expression - repeated post-increment operators after a field reference are a syntax error +oracle_stderr_skip: rshell emits compact parser diagnostics without GNU awk caret rendering. input: program: | BEGIN { diff --git a/tests/awk_scenarios_test.go b/tests/awk_scenarios_test.go index d9759bdf..bff96705 100644 --- a/tests/awk_scenarios_test.go +++ b/tests/awk_scenarios_test.go @@ -24,13 +24,14 @@ import ( ) type awkScenario struct { - Description string `yaml:"description"` - Upstream awkUpstreamMetadata `yaml:"upstream"` - Covers []string `yaml:"covers"` - Skip string `yaml:"skip"` - Setup setup `yaml:"setup"` - Input awkInput `yaml:"input"` - Expect awkExpected `yaml:"expect"` + Description string `yaml:"description"` + Upstream awkUpstreamMetadata `yaml:"upstream"` + Covers []string `yaml:"covers"` + Skip string `yaml:"skip"` + OracleStderrSkip string `yaml:"oracle_stderr_skip"` + Setup setup `yaml:"setup"` + Input awkInput `yaml:"input"` + Expect awkExpected `yaml:"expect"` } type awkUpstreamMetadata struct { @@ -140,7 +141,9 @@ func TestAwkScenarios(t *testing.T) { want := runAwkScenario(t, oracle, sc, timeout) assert.Equal(t, want.exitCode, got.exitCode, "exit code mismatch against GNU awk oracle") assert.Equal(t, want.stdout, got.stdout, "stdout mismatch against GNU awk oracle") - assert.Equal(t, want.stderr, got.stderr, "stderr mismatch against GNU awk oracle") + if sc.OracleStderrSkip == "" { + assert.Equal(t, want.stderr, got.stderr, "stderr mismatch against GNU awk oracle") + } } }) } diff --git a/tests/scenarios/cmd/awk/basic/command_pipe_ordering.yaml b/tests/scenarios/cmd/awk/basic/command_pipe_ordering.yaml new file mode 100644 index 00000000..f4c54cd3 --- /dev/null +++ b/tests/scenarios/cmd/awk/basic/command_pipe_ordering.yaml @@ -0,0 +1,39 @@ +description: awk output command pipes preserve observable output ordering. +oracle: gawk +input: + script: |+ + awk 'BEGIN { print "b" | "cat"; print "a"; close("cat") }' + awk 'BEGIN { cmd = "cat"; print "b" | cmd; print "a"; close(cmd) }' + awk 'BEGIN { print "x" | "cat"; print "z"; print "y" | "cat" }' + awk 'BEGIN { cmd = "awk \"{ n++ } END { print n }\""; print "a" | cmd; printf ""; print "b" | cmd; close(cmd) }' + awk 'BEGIN { print "b" | "sort"; print "mid"; print "a" | "sort"; close("sort") }' + awk 'BEGIN { for (i = 1; i <= 2; i++) { print i | "cat"; print "x" } close("cat") }' + printf '1\n2\n' | awk '{ print $0 | "cat"; print "x" } END { close("cat") }' + awk 'function f(x) { print x | "sort"; print "s" } BEGIN { f("b"); f("a"); close("sort") }' +expect: + stdout: |+ + a + b + a + b + x + y + z + 2 + mid + a + b + x + x + 1 + 2 + x + x + 1 + 2 + s + s + a + b + stderr: |+ + exit_code: 0 diff --git a/tests/scenarios/cmd/awk/basic/composite_keys_ternary_exit.yaml b/tests/scenarios/cmd/awk/basic/composite_keys_ternary_exit.yaml new file mode 100644 index 00000000..1303fc59 --- /dev/null +++ b/tests/scenarios/cmd/awk/basic/composite_keys_ternary_exit.yaml @@ -0,0 +1,18 @@ +description: awk supports composite array keys, ternary expressions, and exit status. +oracle: gawk +input: + script: |+ + awk 'BEGIN { a=0; b=0; print 0 ? a=2 : b=3; print a,b; print 1 ? a=4 : b=5; print a,b }' + printf 'a x 1\na y 2\na x 3\nb x 4\n' | awk '{ count[$1, $2] += $3; label = ($3 > 2 ? "big" : "small"); classes[$1, label]++ } END { print count["a", "x"], count["a", "y"], count["b", "x"]; print classes["a", "small"], classes["a", "big"]; print count["a" SUBSEP "x"], count["a\034x"], classes["a\034small"]; delete count["a", "x"]; print (("a", "x") in count), (("b", "x") in count), length(SUBSEP); exit 7 }' +expect: + stdout: |+ + 3 + 0 3 + 4 + 4 3 + 4 2 4 + 2 1 + 4 4 2 + 0 1 1 + stderr: |+ + exit_code: 7 diff --git a/tests/scenarios/cmd/awk/basic/environ_numeric_string.yaml b/tests/scenarios/cmd/awk/basic/environ_numeric_string.yaml index 141140d5..a3d649ac 100644 --- a/tests/scenarios/cmd/awk/basic/environ_numeric_string.yaml +++ b/tests/scenarios/cmd/awk/basic/environ_numeric_string.yaml @@ -7,8 +7,10 @@ input: NUMERIC_ENV: "10" script: |+ awk 'BEGIN { print ENVIRON["NUMERIC_ENV"] < 2, ENVIRON["NUMERIC_ENV"] + 0, ENVIRON["NUMERIC_ENV"] == 10 }' + awk 'BEGIN { print (length(ENVIRON) > 0), ("NUMERIC_ENV" in ENVIRON) }' expect: stdout: |+ 0 10 1 + 1 1 stderr: |+ exit_code: 0 diff --git a/tests/scenarios/cmd/awk/basic/match_captures_strtonum_prefix.yaml b/tests/scenarios/cmd/awk/basic/match_captures_strtonum_prefix.yaml new file mode 100644 index 00000000..9c1eb760 --- /dev/null +++ b/tests/scenarios/cmd/awk/basic/match_captures_strtonum_prefix.yaml @@ -0,0 +1,12 @@ +description: awk match capture arrays evaluate arguments before clearing and strtonum parses numeric prefixes. +oracle: gawk +input: + script: |+ + awk 'BEGIN { a[1] = "abc"; print match(a[1], /(b)/, a), RSTART, RLENGTH, a[0], a[1]; print strtonum("123abc"), strtonum("-12.5ms"), strtonum("1e3rows"); print strtonum("012.3"), strtonum("012e2"), strtonum("0128"), strtonum("010") }' +expect: + stdout: |+ + 2 2 1 b b + 123 -12.5 1000 + 12.3 1200 128 8 + stderr: |+ + exit_code: 0 diff --git a/tests/scenarios/cmd/awk/basic/regex_literal_after_return_exit.yaml b/tests/scenarios/cmd/awk/basic/regex_literal_after_return_exit.yaml new file mode 100644 index 00000000..32b4f9dd --- /dev/null +++ b/tests/scenarios/cmd/awk/basic/regex_literal_after_return_exit.yaml @@ -0,0 +1,11 @@ +description: awk accepts regex literals after return and exit. +oracle: gawk +input: + script: |+ + awk 'function f(){ return /x/ } BEGIN { $0 = "x"; print f(); $0 = "z"; print f(); $0 = ""; exit /x/ }' +expect: + stdout: |+ + 1 + 0 + stderr: |+ + exit_code: 0 diff --git a/tests/scenarios/cmd/awk/basic/text_substitution_match.yaml b/tests/scenarios/cmd/awk/basic/text_substitution_match.yaml new file mode 100644 index 00000000..5ade6e6c --- /dev/null +++ b/tests/scenarios/cmd/awk/basic/text_substitution_match.yaml @@ -0,0 +1,17 @@ +description: awk supports sub, gsub, match, RSTART, RLENGTH, and sprintf. +oracle: gawk +input: + script: |+ + awk 'BEGIN { s = "abc123def"; print match(s, /[0-9]+/), RSTART, RLENGTH, substr(s, RSTART, RLENGTH); sub(/[0-9]+/, "<&>", s); print s; gsub(/[a-z]+/, "X", s); print s; print sprintf("%s:%03d", "id", 7) }' + awk 'BEGIN { s = "abc"; gsub(/^/, "X", s); print s; s = "abc"; gsub(/$/, "X", s); print s; s = "abc"; gsub(/^|$/, "X", s); print s }' +expect: + stdout: |+ + 4 4 3 123 + abc<123>def + X<123>X + id:007 + Xabc + abcX + XabcX + stderr: |+ + exit_code: 0 diff --git a/tests/scenarios/cmd/awk/patterns/regex_octal_escape.yaml b/tests/scenarios/cmd/awk/patterns/regex_octal_escape.yaml new file mode 100644 index 00000000..8b0eba64 --- /dev/null +++ b/tests/scenarios/cmd/awk/patterns/regex_octal_escape.yaml @@ -0,0 +1,20 @@ +description: awk regular expression literals support octal byte escapes. +oracle: gawk +input: + script: |+ + printf 'a\n141\n.\nx\n' | awk '/\141/ { print "a", $0 } /\056/ { print "dot", $0 }' + printf '\377\n' | awk '/\377/ { print "byte" }' + printf '\377\n' | awk 'BEGIN { r = "\377" } $0 ~ r { print "dynamic" }' + printf '\303\251\377\n' | awk 'BEGIN { r = "\303\251\377" } $0 ~ r { print "mixed" }' +expect: + stdout: |+ + a a + dot a + dot 141 + dot . + dot x + byte + dynamic + mixed + stderr: |+ + exit_code: 0 diff --git a/tests/scenarios/cmd/awk/safety/print_redirect_rejected.yaml b/tests/scenarios/cmd/awk/safety/print_redirect_rejected.yaml index 03b4dd4f..78a80cb0 100644 --- a/tests/scenarios/cmd/awk/safety/print_redirect_rejected.yaml +++ b/tests/scenarios/cmd/awk/safety/print_redirect_rejected.yaml @@ -12,5 +12,5 @@ input: expect: stdout: "" stderr: |+ - awk: print redirection and command pipes are not supported + awk: print redirection is not supported exit_code: 1 diff --git a/tools/awk-harness/rshell-awk b/tools/awk-harness/rshell-awk index 827198df..12d076fd 100755 --- a/tools/awk-harness/rshell-awk +++ b/tools/awk-harness/rshell-awk @@ -3,7 +3,7 @@ set -euo pipefail RSHELL_BIN="${RSHELL_BIN:-./rshell}" -RSHELL_ALLOWED_PATHS="${RSHELL_ALLOWED_PATHS:-/}" +RSHELL_ALLOWED_PATHS="${RSHELL_ALLOWED_PATHS:-$PWD}" die() { printf '[rshell-awk] error: %s\n' "$*" >&2