diff options
Diffstat (limited to 'vendor/github.com/d5/tengo/v2/compiler.go')
-rw-r--r-- | vendor/github.com/d5/tengo/v2/compiler.go | 1312 |
1 files changed, 1312 insertions, 0 deletions
diff --git a/vendor/github.com/d5/tengo/v2/compiler.go b/vendor/github.com/d5/tengo/v2/compiler.go new file mode 100644 index 00000000..eb686ed6 --- /dev/null +++ b/vendor/github.com/d5/tengo/v2/compiler.go @@ -0,0 +1,1312 @@ +package tengo + +import ( + "fmt" + "io" + "io/ioutil" + "path/filepath" + "reflect" + "strings" + + "github.com/d5/tengo/v2/parser" + "github.com/d5/tengo/v2/token" +) + +// compilationScope represents a compiled instructions and the last two +// instructions that were emitted. +type compilationScope struct { + Instructions []byte + SymbolInit map[string]bool + SourceMap map[int]parser.Pos +} + +// loop represents a loop construct that the compiler uses to track the current +// loop. +type loop struct { + Continues []int + Breaks []int +} + +// CompilerError represents a compiler error. +type CompilerError struct { + FileSet *parser.SourceFileSet + Node parser.Node + Err error +} + +func (e *CompilerError) Error() string { + filePos := e.FileSet.Position(e.Node.Pos()) + return fmt.Sprintf("Compile Error: %s\n\tat %s", e.Err.Error(), filePos) +} + +// Compiler compiles the AST into a bytecode. +type Compiler struct { + file *parser.SourceFile + parent *Compiler + modulePath string + constants []Object + symbolTable *SymbolTable + scopes []compilationScope + scopeIndex int + modules *ModuleMap + compiledModules map[string]*CompiledFunction + allowFileImport bool + loops []*loop + loopIndex int + trace io.Writer + indent int +} + +// NewCompiler creates a Compiler. +func NewCompiler( + file *parser.SourceFile, + symbolTable *SymbolTable, + constants []Object, + modules *ModuleMap, + trace io.Writer, +) *Compiler { + mainScope := compilationScope{ + SymbolInit: make(map[string]bool), + SourceMap: make(map[int]parser.Pos), + } + + // symbol table + if symbolTable == nil { + symbolTable = NewSymbolTable() + } + + // add builtin functions to the symbol table + for idx, fn := range builtinFuncs { + symbolTable.DefineBuiltin(idx, fn.Name) + } + + // builtin modules + if modules == nil { + modules = NewModuleMap() + } + + return &Compiler{ + file: file, + symbolTable: symbolTable, + constants: constants, + scopes: []compilationScope{mainScope}, + scopeIndex: 0, + loopIndex: -1, + trace: trace, + modules: modules, + compiledModules: make(map[string]*CompiledFunction), + } +} + +// Compile compiles the AST node. +func (c *Compiler) Compile(node parser.Node) error { + if c.trace != nil { + if node != nil { + defer untracec(tracec(c, fmt.Sprintf("%s (%s)", + node.String(), reflect.TypeOf(node).Elem().Name()))) + } else { + defer untracec(tracec(c, "<nil>")) + } + } + + switch node := node.(type) { + case *parser.File: + for _, stmt := range node.Stmts { + if err := c.Compile(stmt); err != nil { + return err + } + } + case *parser.ExprStmt: + if err := c.Compile(node.Expr); err != nil { + return err + } + c.emit(node, parser.OpPop) + case *parser.IncDecStmt: + op := token.AddAssign + if node.Token == token.Dec { + op = token.SubAssign + } + return c.compileAssign(node, []parser.Expr{node.Expr}, + []parser.Expr{&parser.IntLit{Value: 1}}, op) + case *parser.ParenExpr: + if err := c.Compile(node.Expr); err != nil { + return err + } + case *parser.BinaryExpr: + if node.Token == token.LAnd || node.Token == token.LOr { + return c.compileLogical(node) + } + if node.Token == token.Less { + if err := c.Compile(node.RHS); err != nil { + return err + } + if err := c.Compile(node.LHS); err != nil { + return err + } + c.emit(node, parser.OpBinaryOp, int(token.Greater)) + return nil + } else if node.Token == token.LessEq { + if err := c.Compile(node.RHS); err != nil { + return err + } + if err := c.Compile(node.LHS); err != nil { + return err + } + c.emit(node, parser.OpBinaryOp, int(token.GreaterEq)) + return nil + } + if err := c.Compile(node.LHS); err != nil { + return err + } + if err := c.Compile(node.RHS); err != nil { + return err + } + + switch node.Token { + case token.Add: + c.emit(node, parser.OpBinaryOp, int(token.Add)) + case token.Sub: + c.emit(node, parser.OpBinaryOp, int(token.Sub)) + case token.Mul: + c.emit(node, parser.OpBinaryOp, int(token.Mul)) + case token.Quo: + c.emit(node, parser.OpBinaryOp, int(token.Quo)) + case token.Rem: + c.emit(node, parser.OpBinaryOp, int(token.Rem)) + case token.Greater: + c.emit(node, parser.OpBinaryOp, int(token.Greater)) + case token.GreaterEq: + c.emit(node, parser.OpBinaryOp, int(token.GreaterEq)) + case token.Equal: + c.emit(node, parser.OpEqual) + case token.NotEqual: + c.emit(node, parser.OpNotEqual) + case token.And: + c.emit(node, parser.OpBinaryOp, int(token.And)) + case token.Or: + c.emit(node, parser.OpBinaryOp, int(token.Or)) + case token.Xor: + c.emit(node, parser.OpBinaryOp, int(token.Xor)) + case token.AndNot: + c.emit(node, parser.OpBinaryOp, int(token.AndNot)) + case token.Shl: + c.emit(node, parser.OpBinaryOp, int(token.Shl)) + case token.Shr: + c.emit(node, parser.OpBinaryOp, int(token.Shr)) + default: + return c.errorf(node, "invalid binary operator: %s", + node.Token.String()) + } + case *parser.IntLit: + c.emit(node, parser.OpConstant, + c.addConstant(&Int{Value: node.Value})) + case *parser.FloatLit: + c.emit(node, parser.OpConstant, + c.addConstant(&Float{Value: node.Value})) + case *parser.BoolLit: + if node.Value { + c.emit(node, parser.OpTrue) + } else { + c.emit(node, parser.OpFalse) + } + case *parser.StringLit: + if len(node.Value) > MaxStringLen { + return c.error(node, ErrStringLimit) + } + c.emit(node, parser.OpConstant, + c.addConstant(&String{Value: node.Value})) + case *parser.CharLit: + c.emit(node, parser.OpConstant, + c.addConstant(&Char{Value: node.Value})) + case *parser.UndefinedLit: + c.emit(node, parser.OpNull) + case *parser.UnaryExpr: + if err := c.Compile(node.Expr); err != nil { + return err + } + + switch node.Token { + case token.Not: + c.emit(node, parser.OpLNot) + case token.Sub: + c.emit(node, parser.OpMinus) + case token.Xor: + c.emit(node, parser.OpBComplement) + case token.Add: + // do nothing? + default: + return c.errorf(node, + "invalid unary operator: %s", node.Token.String()) + } + case *parser.IfStmt: + // open new symbol table for the statement + c.symbolTable = c.symbolTable.Fork(true) + defer func() { + c.symbolTable = c.symbolTable.Parent(false) + }() + + if node.Init != nil { + if err := c.Compile(node.Init); err != nil { + return err + } + } + if err := c.Compile(node.Cond); err != nil { + return err + } + + // first jump placeholder + jumpPos1 := c.emit(node, parser.OpJumpFalsy, 0) + if err := c.Compile(node.Body); err != nil { + return err + } + if node.Else != nil { + // second jump placeholder + jumpPos2 := c.emit(node, parser.OpJump, 0) + + // update first jump offset + curPos := len(c.currentInstructions()) + c.changeOperand(jumpPos1, curPos) + if err := c.Compile(node.Else); err != nil { + return err + } + + // update second jump offset + curPos = len(c.currentInstructions()) + c.changeOperand(jumpPos2, curPos) + } else { + // update first jump offset + curPos := len(c.currentInstructions()) + c.changeOperand(jumpPos1, curPos) + } + case *parser.ForStmt: + return c.compileForStmt(node) + case *parser.ForInStmt: + return c.compileForInStmt(node) + case *parser.BranchStmt: + if node.Token == token.Break { + curLoop := c.currentLoop() + if curLoop == nil { + return c.errorf(node, "break not allowed outside loop") + } + pos := c.emit(node, parser.OpJump, 0) + curLoop.Breaks = append(curLoop.Breaks, pos) + } else if node.Token == token.Continue { + curLoop := c.currentLoop() + if curLoop == nil { + return c.errorf(node, "continue not allowed outside loop") + } + pos := c.emit(node, parser.OpJump, 0) + curLoop.Continues = append(curLoop.Continues, pos) + } else { + panic(fmt.Errorf("invalid branch statement: %s", + node.Token.String())) + } + case *parser.BlockStmt: + if len(node.Stmts) == 0 { + return nil + } + + c.symbolTable = c.symbolTable.Fork(true) + defer func() { + c.symbolTable = c.symbolTable.Parent(false) + }() + + for _, stmt := range node.Stmts { + if err := c.Compile(stmt); err != nil { + return err + } + } + case *parser.AssignStmt: + err := c.compileAssign(node, node.LHS, node.RHS, node.Token) + if err != nil { + return err + } + case *parser.Ident: + symbol, _, ok := c.symbolTable.Resolve(node.Name) + if !ok { + return c.errorf(node, "unresolved reference '%s'", node.Name) + } + + switch symbol.Scope { + case ScopeGlobal: + c.emit(node, parser.OpGetGlobal, symbol.Index) + case ScopeLocal: + c.emit(node, parser.OpGetLocal, symbol.Index) + case ScopeBuiltin: + c.emit(node, parser.OpGetBuiltin, symbol.Index) + case ScopeFree: + c.emit(node, parser.OpGetFree, symbol.Index) + } + case *parser.ArrayLit: + for _, elem := range node.Elements { + if err := c.Compile(elem); err != nil { + return err + } + } + c.emit(node, parser.OpArray, len(node.Elements)) + case *parser.MapLit: + for _, elt := range node.Elements { + // key + if len(elt.Key) > MaxStringLen { + return c.error(node, ErrStringLimit) + } + c.emit(node, parser.OpConstant, + c.addConstant(&String{Value: elt.Key})) + + // value + if err := c.Compile(elt.Value); err != nil { + return err + } + } + c.emit(node, parser.OpMap, len(node.Elements)*2) + + case *parser.SelectorExpr: // selector on RHS side + if err := c.Compile(node.Expr); err != nil { + return err + } + if err := c.Compile(node.Sel); err != nil { + return err + } + c.emit(node, parser.OpIndex) + case *parser.IndexExpr: + if err := c.Compile(node.Expr); err != nil { + return err + } + if err := c.Compile(node.Index); err != nil { + return err + } + c.emit(node, parser.OpIndex) + case *parser.SliceExpr: + if err := c.Compile(node.Expr); err != nil { + return err + } + if node.Low != nil { + if err := c.Compile(node.Low); err != nil { + return err + } + } else { + c.emit(node, parser.OpNull) + } + if node.High != nil { + if err := c.Compile(node.High); err != nil { + return err + } + } else { + c.emit(node, parser.OpNull) + } + c.emit(node, parser.OpSliceIndex) + case *parser.FuncLit: + c.enterScope() + + for _, p := range node.Type.Params.List { + s := c.symbolTable.Define(p.Name) + + // function arguments is not assigned directly. + s.LocalAssigned = true + } + + if err := c.Compile(node.Body); err != nil { + return err + } + + // code optimization + c.optimizeFunc(node) + + freeSymbols := c.symbolTable.FreeSymbols() + numLocals := c.symbolTable.MaxSymbols() + instructions, sourceMap := c.leaveScope() + + for _, s := range freeSymbols { + switch s.Scope { + case ScopeLocal: + if !s.LocalAssigned { + // Here, the closure is capturing a local variable that's + // not yet assigned its value. One example is a local + // recursive function: + // + // func() { + // foo := func(x) { + // // .. + // return foo(x-1) + // } + // } + // + // which translate into + // + // 0000 GETL 0 + // 0002 CLOSURE ? 1 + // 0006 DEFL 0 + // + // . So the local variable (0) is being captured before + // it's assigned the value. + // + // Solution is to transform the code into something like + // this: + // + // func() { + // foo := undefined + // foo = func(x) { + // // .. + // return foo(x-1) + // } + // } + // + // that is equivalent to + // + // 0000 NULL + // 0001 DEFL 0 + // 0003 GETL 0 + // 0005 CLOSURE ? 1 + // 0009 SETL 0 + // + c.emit(node, parser.OpNull) + c.emit(node, parser.OpDefineLocal, s.Index) + s.LocalAssigned = true + } + c.emit(node, parser.OpGetLocalPtr, s.Index) + case ScopeFree: + c.emit(node, parser.OpGetFreePtr, s.Index) + } + } + + compiledFunction := &CompiledFunction{ + Instructions: instructions, + NumLocals: numLocals, + NumParameters: len(node.Type.Params.List), + VarArgs: node.Type.Params.VarArgs, + SourceMap: sourceMap, + } + if len(freeSymbols) > 0 { + c.emit(node, parser.OpClosure, + c.addConstant(compiledFunction), len(freeSymbols)) + } else { + c.emit(node, parser.OpConstant, c.addConstant(compiledFunction)) + } + case *parser.ReturnStmt: + if c.symbolTable.Parent(true) == nil { + // outside the function + return c.errorf(node, "return not allowed outside function") + } + + if node.Result == nil { + c.emit(node, parser.OpReturn, 0) + } else { + if err := c.Compile(node.Result); err != nil { + return err + } + c.emit(node, parser.OpReturn, 1) + } + case *parser.CallExpr: + if err := c.Compile(node.Func); err != nil { + return err + } + for _, arg := range node.Args { + if err := c.Compile(arg); err != nil { + return err + } + } + c.emit(node, parser.OpCall, len(node.Args)) + case *parser.ImportExpr: + if node.ModuleName == "" { + return c.errorf(node, "empty module name") + } + + if mod := c.modules.Get(node.ModuleName); mod != nil { + v, err := mod.Import(node.ModuleName) + if err != nil { + return err + } + + switch v := v.(type) { + case []byte: // module written in Tengo + compiled, err := c.compileModule(node, + node.ModuleName, node.ModuleName, v) + if err != nil { + return err + } + c.emit(node, parser.OpConstant, c.addConstant(compiled)) + c.emit(node, parser.OpCall, 0) + case Object: // builtin module + c.emit(node, parser.OpConstant, c.addConstant(v)) + default: + panic(fmt.Errorf("invalid import value type: %T", v)) + } + } else if c.allowFileImport { + moduleName := node.ModuleName + if !strings.HasSuffix(moduleName, ".tengo") { + moduleName += ".tengo" + } + + modulePath, err := filepath.Abs(moduleName) + if err != nil { + return c.errorf(node, "module file path error: %s", + err.Error()) + } + + if err := c.checkCyclicImports(node, modulePath); err != nil { + return err + } + + moduleSrc, err := ioutil.ReadFile(moduleName) + if err != nil { + return c.errorf(node, "module file read error: %s", + err.Error()) + } + + compiled, err := c.compileModule(node, + moduleName, modulePath, moduleSrc) + if err != nil { + return err + } + c.emit(node, parser.OpConstant, c.addConstant(compiled)) + c.emit(node, parser.OpCall, 0) + } else { + return c.errorf(node, "module '%s' not found", node.ModuleName) + } + case *parser.ExportStmt: + // export statement must be in top-level scope + if c.scopeIndex != 0 { + return c.errorf(node, "export not allowed inside function") + } + + // export statement is simply ignore when compiling non-module code + if c.parent == nil { + break + } + if err := c.Compile(node.Result); err != nil { + return err + } + c.emit(node, parser.OpImmutable) + c.emit(node, parser.OpReturn, 1) + case *parser.ErrorExpr: + if err := c.Compile(node.Expr); err != nil { + return err + } + c.emit(node, parser.OpError) + case *parser.ImmutableExpr: + if err := c.Compile(node.Expr); err != nil { + return err + } + c.emit(node, parser.OpImmutable) + case *parser.CondExpr: + if err := c.Compile(node.Cond); err != nil { + return err + } + + // first jump placeholder + jumpPos1 := c.emit(node, parser.OpJumpFalsy, 0) + if err := c.Compile(node.True); err != nil { + return err + } + + // second jump placeholder + jumpPos2 := c.emit(node, parser.OpJump, 0) + + // update first jump offset + curPos := len(c.currentInstructions()) + c.changeOperand(jumpPos1, curPos) + if err := c.Compile(node.False); err != nil { + return err + } + + // update second jump offset + curPos = len(c.currentInstructions()) + c.changeOperand(jumpPos2, curPos) + } + return nil +} + +// Bytecode returns a compiled bytecode. +func (c *Compiler) Bytecode() *Bytecode { + return &Bytecode{ + FileSet: c.file.Set(), + MainFunction: &CompiledFunction{ + Instructions: append(c.currentInstructions(), parser.OpSuspend), + SourceMap: c.currentSourceMap(), + }, + Constants: c.constants, + } +} + +// EnableFileImport enables or disables module loading from local files. +// Local file modules are disabled by default. +func (c *Compiler) EnableFileImport(enable bool) { + c.allowFileImport = enable +} + +func (c *Compiler) compileAssign( + node parser.Node, + lhs, rhs []parser.Expr, + op token.Token, +) error { + numLHS, numRHS := len(lhs), len(rhs) + if numLHS > 1 || numRHS > 1 { + return c.errorf(node, "tuple assignment not allowed") + } + + // resolve and compile left-hand side + ident, selectors := resolveAssignLHS(lhs[0]) + numSel := len(selectors) + + if op == token.Define && numSel > 0 { + // using selector on new variable does not make sense + return c.errorf(node, "operator ':=' not allowed with selector") + } + + symbol, depth, exists := c.symbolTable.Resolve(ident) + if op == token.Define { + if depth == 0 && exists { + return c.errorf(node, "'%s' redeclared in this block", ident) + } + symbol = c.symbolTable.Define(ident) + } else { + if !exists { + return c.errorf(node, "unresolved reference '%s'", ident) + } + } + + // +=, -=, *=, /= + if op != token.Assign && op != token.Define { + if err := c.Compile(lhs[0]); err != nil { + return err + } + } + + // compile RHSs + for _, expr := range rhs { + if err := c.Compile(expr); err != nil { + return err + } + } + + switch op { + case token.AddAssign: + c.emit(node, parser.OpBinaryOp, int(token.Add)) + case token.SubAssign: + c.emit(node, parser.OpBinaryOp, int(token.Sub)) + case token.MulAssign: + c.emit(node, parser.OpBinaryOp, int(token.Mul)) + case token.QuoAssign: + c.emit(node, parser.OpBinaryOp, int(token.Quo)) + case token.RemAssign: + c.emit(node, parser.OpBinaryOp, int(token.Rem)) + case token.AndAssign: + c.emit(node, parser.OpBinaryOp, int(token.And)) + case token.OrAssign: + c.emit(node, parser.OpBinaryOp, int(token.Or)) + case token.AndNotAssign: + c.emit(node, parser.OpBinaryOp, int(token.AndNot)) + case token.XorAssign: + c.emit(node, parser.OpBinaryOp, int(token.Xor)) + case token.ShlAssign: + c.emit(node, parser.OpBinaryOp, int(token.Shl)) + case token.ShrAssign: + c.emit(node, parser.OpBinaryOp, int(token.Shr)) + } + + // compile selector expressions (right to left) + for i := numSel - 1; i >= 0; i-- { + if err := c.Compile(selectors[i]); err != nil { + return err + } + } + + switch symbol.Scope { + case ScopeGlobal: + if numSel > 0 { + c.emit(node, parser.OpSetSelGlobal, symbol.Index, numSel) + } else { + c.emit(node, parser.OpSetGlobal, symbol.Index) + } + case ScopeLocal: + if numSel > 0 { + c.emit(node, parser.OpSetSelLocal, symbol.Index, numSel) + } else { + if op == token.Define && !symbol.LocalAssigned { + c.emit(node, parser.OpDefineLocal, symbol.Index) + } else { + c.emit(node, parser.OpSetLocal, symbol.Index) + } + } + + // mark the symbol as local-assigned + symbol.LocalAssigned = true + case ScopeFree: + if numSel > 0 { + c.emit(node, parser.OpSetSelFree, symbol.Index, numSel) + } else { + c.emit(node, parser.OpSetFree, symbol.Index) + } + default: + panic(fmt.Errorf("invalid assignment variable scope: %s", + symbol.Scope)) + } + return nil +} + +func (c *Compiler) compileLogical(node *parser.BinaryExpr) error { + // left side term + if err := c.Compile(node.LHS); err != nil { + return err + } + + // jump position + var jumpPos int + if node.Token == token.LAnd { + jumpPos = c.emit(node, parser.OpAndJump, 0) + } else { + jumpPos = c.emit(node, parser.OpOrJump, 0) + } + + // right side term + if err := c.Compile(node.RHS); err != nil { + return err + } + + c.changeOperand(jumpPos, len(c.currentInstructions())) + return nil +} + +func (c *Compiler) compileForStmt(stmt *parser.ForStmt) error { + c.symbolTable = c.symbolTable.Fork(true) + defer func() { + c.symbolTable = c.symbolTable.Parent(false) + }() + + // init statement + if stmt.Init != nil { + if err := c.Compile(stmt.Init); err != nil { + return err + } + } + + // pre-condition position + preCondPos := len(c.currentInstructions()) + + // condition expression + postCondPos := -1 + if stmt.Cond != nil { + if err := c.Compile(stmt.Cond); err != nil { + return err + } + // condition jump position + postCondPos = c.emit(stmt, parser.OpJumpFalsy, 0) + } + + // enter loop + loop := c.enterLoop() + + // body statement + if err := c.Compile(stmt.Body); err != nil { + c.leaveLoop() + return err + } + + c.leaveLoop() + + // post-body position + postBodyPos := len(c.currentInstructions()) + + // post statement + if stmt.Post != nil { + if err := c.Compile(stmt.Post); err != nil { + return err + } + } + + // back to condition + c.emit(stmt, parser.OpJump, preCondPos) + + // post-statement position + postStmtPos := len(c.currentInstructions()) + if postCondPos >= 0 { + c.changeOperand(postCondPos, postStmtPos) + } + + // update all break/continue jump positions + for _, pos := range loop.Breaks { + c.changeOperand(pos, postStmtPos) + } + for _, pos := range loop.Continues { + c.changeOperand(pos, postBodyPos) + } + return nil +} + +func (c *Compiler) compileForInStmt(stmt *parser.ForInStmt) error { + c.symbolTable = c.symbolTable.Fork(true) + defer func() { + c.symbolTable = c.symbolTable.Parent(false) + }() + + // for-in statement is compiled like following: + // + // for :it := iterator(iterable); :it.next(); { + // k, v := :it.get() // DEFINE operator + // + // ... body ... + // } + // + // ":it" is a local variable but will be conflict with other user variables + // because character ":" is not allowed. + + // init + // :it = iterator(iterable) + itSymbol := c.symbolTable.Define(":it") + if err := c.Compile(stmt.Iterable); err != nil { + return err + } + c.emit(stmt, parser.OpIteratorInit) + if itSymbol.Scope == ScopeGlobal { + c.emit(stmt, parser.OpSetGlobal, itSymbol.Index) + } else { + c.emit(stmt, parser.OpDefineLocal, itSymbol.Index) + } + + // pre-condition position + preCondPos := len(c.currentInstructions()) + + // condition + // :it.HasMore() + if itSymbol.Scope == ScopeGlobal { + c.emit(stmt, parser.OpGetGlobal, itSymbol.Index) + } else { + c.emit(stmt, parser.OpGetLocal, itSymbol.Index) + } + c.emit(stmt, parser.OpIteratorNext) + + // condition jump position + postCondPos := c.emit(stmt, parser.OpJumpFalsy, 0) + + // enter loop + loop := c.enterLoop() + + // assign key variable + if stmt.Key.Name != "_" { + keySymbol := c.symbolTable.Define(stmt.Key.Name) + if itSymbol.Scope == ScopeGlobal { + c.emit(stmt, parser.OpGetGlobal, itSymbol.Index) + } else { + c.emit(stmt, parser.OpGetLocal, itSymbol.Index) + } + c.emit(stmt, parser.OpIteratorKey) + if keySymbol.Scope == ScopeGlobal { + c.emit(stmt, parser.OpSetGlobal, keySymbol.Index) + } else { + c.emit(stmt, parser.OpDefineLocal, keySymbol.Index) + } + } + + // assign value variable + if stmt.Value.Name != "_" { + valueSymbol := c.symbolTable.Define(stmt.Value.Name) + if itSymbol.Scope == ScopeGlobal { + c.emit(stmt, parser.OpGetGlobal, itSymbol.Index) + } else { + c.emit(stmt, parser.OpGetLocal, itSymbol.Index) + } + c.emit(stmt, parser.OpIteratorValue) + if valueSymbol.Scope == ScopeGlobal { + c.emit(stmt, parser.OpSetGlobal, valueSymbol.Index) + } else { + c.emit(stmt, parser.OpDefineLocal, valueSymbol.Index) + } + } + + // body statement + if err := c.Compile(stmt.Body); err != nil { + c.leaveLoop() + return err + } + + c.leaveLoop() + + // post-body position + postBodyPos := len(c.currentInstructions()) + + // back to condition + c.emit(stmt, parser.OpJump, preCondPos) + + // post-statement position + postStmtPos := len(c.currentInstructions()) + c.changeOperand(postCondPos, postStmtPos) + + // update all break/continue jump positions + for _, pos := range loop.Breaks { + c.changeOperand(pos, postStmtPos) + } + for _, pos := range loop.Continues { + c.changeOperand(pos, postBodyPos) + } + return nil +} + +func (c *Compiler) checkCyclicImports( + node parser.Node, + modulePath string, +) error { + if c.modulePath == modulePath { + return c.errorf(node, "cyclic module import: %s", modulePath) + } else if c.parent != nil { + return c.parent.checkCyclicImports(node, modulePath) + } + return nil +} + +func (c *Compiler) compileModule( + node parser.Node, + moduleName, modulePath string, + src []byte, +) (*CompiledFunction, error) { + if err := c.checkCyclicImports(node, modulePath); err != nil { + return nil, err + } + + compiledModule, exists := c.loadCompiledModule(modulePath) + if exists { + return compiledModule, nil + } + + modFile := c.file.Set().AddFile(moduleName, -1, len(src)) + p := parser.NewParser(modFile, src, nil) + file, err := p.ParseFile() + if err != nil { + return nil, err + } + + // inherit builtin functions + symbolTable := NewSymbolTable() + for _, sym := range c.symbolTable.BuiltinSymbols() { + symbolTable.DefineBuiltin(sym.Index, sym.Name) + } + + // no global scope for the module + symbolTable = symbolTable.Fork(false) + + // compile module + moduleCompiler := c.fork(modFile, modulePath, symbolTable) + if err := moduleCompiler.Compile(file); err != nil { + return nil, err + } + + // code optimization + moduleCompiler.optimizeFunc(node) + compiledFunc := moduleCompiler.Bytecode().MainFunction + compiledFunc.NumLocals = symbolTable.MaxSymbols() + c.storeCompiledModule(modulePath, compiledFunc) + return compiledFunc, nil +} + +func (c *Compiler) loadCompiledModule( + modulePath string, +) (mod *CompiledFunction, ok bool) { + if c.parent != nil { + return c.parent.loadCompiledModule(modulePath) + } + mod, ok = c.compiledModules[modulePath] + return +} + +func (c *Compiler) storeCompiledModule( + modulePath string, + module *CompiledFunction, +) { + if c.parent != nil { + c.parent.storeCompiledModule(modulePath, module) + } + c.compiledModules[modulePath] = module +} + +func (c *Compiler) enterLoop() *loop { + loop := &loop{} + c.loops = append(c.loops, loop) + c.loopIndex++ + if c.trace != nil { + c.printTrace("LOOPE", c.loopIndex) + } + return loop +} + +func (c *Compiler) leaveLoop() { + if c.trace != nil { + c.printTrace("LOOPL", c.loopIndex) + } + c.loops = c.loops[:len(c.loops)-1] + c.loopIndex-- +} + +func (c *Compiler) currentLoop() *loop { + if c.loopIndex >= 0 { + return c.loops[c.loopIndex] + } + return nil +} + +func (c *Compiler) currentInstructions() []byte { + return c.scopes[c.scopeIndex].Instructions +} + +func (c *Compiler) currentSourceMap() map[int]parser.Pos { + return c.scopes[c.scopeIndex].SourceMap +} + +func (c *Compiler) enterScope() { + scope := compilationScope{ + SymbolInit: make(map[string]bool), + SourceMap: make(map[int]parser.Pos), + } + c.scopes = append(c.scopes, scope) + c.scopeIndex++ + c.symbolTable = c.symbolTable.Fork(false) + if c.trace != nil { + c.printTrace("SCOPE", c.scopeIndex) + } +} + +func (c *Compiler) leaveScope() ( + instructions []byte, + sourceMap map[int]parser.Pos, +) { + instructions = c.currentInstructions() + sourceMap = c.currentSourceMap() + c.scopes = c.scopes[:len(c.scopes)-1] + c.scopeIndex-- + c.symbolTable = c.symbolTable.Parent(true) + if c.trace != nil { + c.printTrace("SCOPL", c.scopeIndex) + } + return +} + +func (c *Compiler) fork( + file *parser.SourceFile, + modulePath string, + symbolTable *SymbolTable, +) *Compiler { + child := NewCompiler(file, symbolTable, nil, c.modules, c.trace) + child.modulePath = modulePath // module file path + child.parent = c // parent to set to current compiler + return child +} + +func (c *Compiler) error(node parser.Node, err error) error { + return &CompilerError{ + FileSet: c.file.Set(), + Node: node, + Err: err, + } +} + +func (c *Compiler) errorf( + node parser.Node, + format string, + args ...interface{}, +) error { + return &CompilerError{ + FileSet: c.file.Set(), + Node: node, + Err: fmt.Errorf(format, args...), + } +} + +func (c *Compiler) addConstant(o Object) int { + if c.parent != nil { + // module compilers will use their parent's constants array + return c.parent.addConstant(o) + } + c.constants = append(c.constants, o) + if c.trace != nil { + c.printTrace(fmt.Sprintf("CONST %04d %s", len(c.constants)-1, o)) + } + return len(c.constants) - 1 +} + +func (c *Compiler) addInstruction(b []byte) int { + posNewIns := len(c.currentInstructions()) + c.scopes[c.scopeIndex].Instructions = append( + c.currentInstructions(), b...) + return posNewIns +} + +func (c *Compiler) replaceInstruction(pos int, inst []byte) { + copy(c.currentInstructions()[pos:], inst) + if c.trace != nil { + c.printTrace(fmt.Sprintf("REPLC %s", + FormatInstructions( + c.scopes[c.scopeIndex].Instructions[pos:], pos)[0])) + } +} + +func (c *Compiler) changeOperand(opPos int, operand ...int) { + op := c.currentInstructions()[opPos] + inst := MakeInstruction(op, operand...) + c.replaceInstruction(opPos, inst) +} + +// optimizeFunc performs some code-level optimization for the current function +// instructions. It also removes unreachable (dead code) instructions and adds +// "returns" instruction if needed. +func (c *Compiler) optimizeFunc(node parser.Node) { + // any instructions between RETURN and the function end + // or instructions between RETURN and jump target position + // are considered as unreachable. + + // pass 1. identify all jump destinations + dsts := make(map[int]bool) + iterateInstructions(c.scopes[c.scopeIndex].Instructions, + func(pos int, opcode parser.Opcode, operands []int) bool { + switch opcode { + case parser.OpJump, parser.OpJumpFalsy, + parser.OpAndJump, parser.OpOrJump: + dsts[operands[0]] = true + } + return true + }) + + // pass 2. eliminate dead code + var newInsts []byte + posMap := make(map[int]int) // old position to new position + var dstIdx int + var deadCode bool + iterateInstructions(c.scopes[c.scopeIndex].Instructions, + func(pos int, opcode parser.Opcode, operands []int) bool { + switch { + case opcode == parser.OpReturn: + if deadCode { + return true + } + deadCode = true + case dsts[pos]: + dstIdx++ + deadCode = false + case deadCode: + return true + } + posMap[pos] = len(newInsts) + newInsts = append(newInsts, + MakeInstruction(opcode, operands...)...) + return true + }) + + // pass 3. update jump positions + var lastOp parser.Opcode + var appendReturn bool + endPos := len(c.scopes[c.scopeIndex].Instructions) + iterateInstructions(newInsts, + func(pos int, opcode parser.Opcode, operands []int) bool { + switch opcode { + case parser.OpJump, parser.OpJumpFalsy, parser.OpAndJump, + parser.OpOrJump: + newDst, ok := posMap[operands[0]] + if ok { + copy(newInsts[pos:], + MakeInstruction(opcode, newDst)) + } else if endPos == operands[0] { + // there's a jump instruction that jumps to the end of + // function compiler should append "return". + appendReturn = true + } else { + panic(fmt.Errorf("invalid jump position: %d", newDst)) + } + } + lastOp = opcode + return true + }) + if lastOp != parser.OpReturn { + appendReturn = true + } + + // pass 4. update source map + newSourceMap := make(map[int]parser.Pos) + for pos, srcPos := range c.scopes[c.scopeIndex].SourceMap { + newPos, ok := posMap[pos] + if ok { + newSourceMap[newPos] = srcPos + } + } + c.scopes[c.scopeIndex].Instructions = newInsts + c.scopes[c.scopeIndex].SourceMap = newSourceMap + + // append "return" + if appendReturn { + c.emit(node, parser.OpReturn, 0) + } +} + +func (c *Compiler) emit( + node parser.Node, + opcode parser.Opcode, + operands ...int, +) int { + filePos := parser.NoPos + if node != nil { + filePos = node.Pos() + } + + inst := MakeInstruction(opcode, operands...) + pos := c.addInstruction(inst) + c.scopes[c.scopeIndex].SourceMap[pos] = filePos + if c.trace != nil { + c.printTrace(fmt.Sprintf("EMIT %s", + FormatInstructions( + c.scopes[c.scopeIndex].Instructions[pos:], pos)[0])) + } + return pos +} + +func (c *Compiler) printTrace(a ...interface{}) { + const ( + dots = ". . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . " + n = len(dots) + ) + + i := 2 * c.indent + for i > n { + _, _ = fmt.Fprint(c.trace, dots) + i -= n + } + _, _ = fmt.Fprint(c.trace, dots[0:i]) + _, _ = fmt.Fprintln(c.trace, a...) +} + +func resolveAssignLHS( + expr parser.Expr, +) (name string, selectors []parser.Expr) { + switch term := expr.(type) { + case *parser.SelectorExpr: + name, selectors = resolveAssignLHS(term.Expr) + selectors = append(selectors, term.Sel) + return + case *parser.IndexExpr: + name, selectors = resolveAssignLHS(term.Expr) + selectors = append(selectors, term.Index) + case *parser.Ident: + name = term.Name + } + return +} + +func iterateInstructions( + b []byte, + fn func(pos int, opcode parser.Opcode, operands []int) bool, +) { + for i := 0; i < len(b); i++ { + numOperands := parser.OpcodeOperands[b[i]] + operands, read := parser.ReadOperands(numOperands, b[i+1:]) + if !fn(i, b[i], operands) { + break + } + i += read + } +} + +func tracec(c *Compiler, msg string) *Compiler { + c.printTrace(msg, "{") + c.indent++ + return c +} + +func untracec(c *Compiler) { + c.indent-- + c.printTrace("}") +} |