package tengo import ( "fmt" "io" "io/ioutil" "path/filepath" "reflect" "strings" "github.com/d5/tengo/v2/parser" "github.com/d5/tengo/v2/token" ) // compilationScope represents a compiled instructions and the last two // instructions that were emitted. type compilationScope struct { Instructions []byte SymbolInit map[string]bool SourceMap map[int]parser.Pos } // loop represents a loop construct that the compiler uses to track the current // loop. type loop struct { Continues []int Breaks []int } // CompilerError represents a compiler error. type CompilerError struct { FileSet *parser.SourceFileSet Node parser.Node Err error } func (e *CompilerError) Error() string { filePos := e.FileSet.Position(e.Node.Pos()) return fmt.Sprintf("Compile Error: %s\n\tat %s", e.Err.Error(), filePos) } // Compiler compiles the AST into a bytecode. type Compiler struct { file *parser.SourceFile parent *Compiler modulePath string constants []Object symbolTable *SymbolTable scopes []compilationScope scopeIndex int modules *ModuleMap compiledModules map[string]*CompiledFunction allowFileImport bool loops []*loop loopIndex int trace io.Writer indent int } // NewCompiler creates a Compiler. func NewCompiler( file *parser.SourceFile, symbolTable *SymbolTable, constants []Object, modules *ModuleMap, trace io.Writer, ) *Compiler { mainScope := compilationScope{ SymbolInit: make(map[string]bool), SourceMap: make(map[int]parser.Pos), } // symbol table if symbolTable == nil { symbolTable = NewSymbolTable() } // add builtin functions to the symbol table for idx, fn := range builtinFuncs { symbolTable.DefineBuiltin(idx, fn.Name) } // builtin modules if modules == nil { modules = NewModuleMap() } return &Compiler{ file: file, symbolTable: symbolTable, constants: constants, scopes: []compilationScope{mainScope}, scopeIndex: 0, loopIndex: -1, trace: trace, modules: modules, compiledModules: make(map[string]*CompiledFunction), } } // Compile compiles the AST node. 
//
// Compile dispatches on the concrete AST node type and appends the matching
// bytecode to the current compilation scope. Statements and expressions are
// compiled recursively. The first error encountered is returned (usually a
// *CompilerError carrying the node's position).
func (c *Compiler) Compile(node parser.Node) error {
	// When tracing, log the node being compiled and indent nested trace
	// output until this call returns.
	if c.trace != nil {
		if node != nil {
			defer untracec(tracec(c, fmt.Sprintf("%s (%s)",
				node.String(), reflect.TypeOf(node).Elem().Name())))
		} else {
			defer untracec(tracec(c, "<nil>"))
		}
	}

	switch node := node.(type) {
	case *parser.File:
		for _, stmt := range node.Stmts {
			if err := c.Compile(stmt); err != nil {
				return err
			}
		}
	case *parser.ExprStmt:
		if err := c.Compile(node.Expr); err != nil {
			return err
		}
		// an expression statement discards its value
		c.emit(node, parser.OpPop)
	case *parser.IncDecStmt:
		// `x++` / `x--` are compiled as `x += 1` / `x -= 1`
		op := token.AddAssign
		if node.Token == token.Dec {
			op = token.SubAssign
		}
		return c.compileAssign(node, []parser.Expr{node.Expr},
			[]parser.Expr{&parser.IntLit{Value: 1}}, op)
	case *parser.ParenExpr:
		if err := c.Compile(node.Expr); err != nil {
			return err
		}
	case *parser.BinaryExpr:
		// && and || need short-circuit jumps
		if node.Token == token.LAnd || node.Token == token.LOr {
			return c.compileLogical(node)
		}

		// `a < b` is compiled as `b > a` and `a <= b` as `b >= a`: the
		// operands are swapped (so RHS is evaluated first) and only the
		// Greater/GreaterEq operators are emitted.
		if node.Token == token.Less {
			if err := c.Compile(node.RHS); err != nil {
				return err
			}
			if err := c.Compile(node.LHS); err != nil {
				return err
			}
			c.emit(node, parser.OpBinaryOp, int(token.Greater))
			return nil
		} else if node.Token == token.LessEq {
			if err := c.Compile(node.RHS); err != nil {
				return err
			}
			if err := c.Compile(node.LHS); err != nil {
				return err
			}
			c.emit(node, parser.OpBinaryOp, int(token.GreaterEq))
			return nil
		}

		if err := c.Compile(node.LHS); err != nil {
			return err
		}
		if err := c.Compile(node.RHS); err != nil {
			return err
		}

		switch node.Token {
		case token.Add:
			c.emit(node, parser.OpBinaryOp, int(token.Add))
		case token.Sub:
			c.emit(node, parser.OpBinaryOp, int(token.Sub))
		case token.Mul:
			c.emit(node, parser.OpBinaryOp, int(token.Mul))
		case token.Quo:
			c.emit(node, parser.OpBinaryOp, int(token.Quo))
		case token.Rem:
			c.emit(node, parser.OpBinaryOp, int(token.Rem))
		case token.Greater:
			c.emit(node, parser.OpBinaryOp, int(token.Greater))
		case token.GreaterEq:
			c.emit(node, parser.OpBinaryOp, int(token.GreaterEq))
		case token.Equal:
			// equality has dedicated opcodes rather than OpBinaryOp
			c.emit(node, parser.OpEqual)
		case token.NotEqual:
			c.emit(node, parser.OpNotEqual)
		case token.And:
			c.emit(node, parser.OpBinaryOp, int(token.And))
		case token.Or:
			c.emit(node, parser.OpBinaryOp, int(token.Or))
		case token.Xor:
			c.emit(node, parser.OpBinaryOp, int(token.Xor))
		case token.AndNot:
			c.emit(node, parser.OpBinaryOp, int(token.AndNot))
		case token.Shl:
			c.emit(node, parser.OpBinaryOp, int(token.Shl))
		case token.Shr:
			c.emit(node, parser.OpBinaryOp, int(token.Shr))
		default:
			return c.errorf(node, "invalid binary operator: %s",
				node.Token.String())
		}
	case *parser.IntLit:
		c.emit(node, parser.OpConstant,
			c.addConstant(&Int{Value: node.Value}))
	case *parser.FloatLit:
		c.emit(node, parser.OpConstant,
			c.addConstant(&Float{Value: node.Value}))
	case *parser.BoolLit:
		if node.Value {
			c.emit(node, parser.OpTrue)
		} else {
			c.emit(node, parser.OpFalse)
		}
	case *parser.StringLit:
		// string literals longer than MaxStringLen are rejected at compile time
		if len(node.Value) > MaxStringLen {
			return c.error(node, ErrStringLimit)
		}
		c.emit(node, parser.OpConstant,
			c.addConstant(&String{Value: node.Value}))
	case *parser.CharLit:
		c.emit(node, parser.OpConstant,
			c.addConstant(&Char{Value: node.Value}))
	case *parser.UndefinedLit:
		c.emit(node, parser.OpNull)
	case *parser.UnaryExpr:
		if err := c.Compile(node.Expr); err != nil {
			return err
		}

		switch node.Token {
		case token.Not:
			c.emit(node, parser.OpLNot)
		case token.Sub:
			c.emit(node, parser.OpMinus)
		case token.Xor:
			c.emit(node, parser.OpBComplement)
		case token.Add:
			// unary plus emits nothing: the operand value is left as-is
		default:
			return c.errorf(node,
				"invalid unary operator: %s", node.Token.String())
		}
	case *parser.IfStmt:
		// open new symbol table for the statement, so variables declared in
		// Init are scoped to the if/else
		c.symbolTable = c.symbolTable.Fork(true)
		defer func() {
			c.symbolTable = c.symbolTable.Parent(false)
		}()

		if node.Init != nil {
			if err := c.Compile(node.Init); err != nil {
				return err
			}
		}
		if err := c.Compile(node.Cond); err != nil {
			return err
		}

		// first jump placeholder (to the else branch or past the body);
		// the operand is patched below once the target is known
		jumpPos1 := c.emit(node, parser.OpJumpFalsy, 0)
		if err := c.Compile(node.Body); err != nil {
			return err
		}
		if node.Else != nil {
			// second jump placeholder (body skips over the else branch)
			jumpPos2 := c.emit(node, parser.OpJump, 0)

			// update first jump offset
			curPos := len(c.currentInstructions())
			c.changeOperand(jumpPos1, curPos)
			if err := c.Compile(node.Else); err != nil {
				return err
			}

			// update second jump offset
			curPos = len(c.currentInstructions())
			c.changeOperand(jumpPos2, curPos)
		} else {
			// update first jump offset
			curPos := len(c.currentInstructions())
			c.changeOperand(jumpPos1, curPos)
		}
	case *parser.ForStmt:
		return c.compileForStmt(node)
	case *parser.ForInStmt:
		return c.compileForInStmt(node)
	case *parser.BranchStmt:
		// break/continue emit a placeholder jump that the enclosing loop
		// compiler patches after the loop body is compiled
		if node.Token == token.Break {
			curLoop := c.currentLoop()
			if curLoop == nil {
				return c.errorf(node, "break not allowed outside loop")
			}
			pos := c.emit(node, parser.OpJump, 0)
			curLoop.Breaks = append(curLoop.Breaks, pos)
		} else if node.Token == token.Continue {
			curLoop := c.currentLoop()
			if curLoop == nil {
				return c.errorf(node, "continue not allowed outside loop")
			}
			pos := c.emit(node, parser.OpJump, 0)
			curLoop.Continues = append(curLoop.Continues, pos)
		} else {
			panic(fmt.Errorf("invalid branch statement: %s",
				node.Token.String()))
		}
	case *parser.BlockStmt:
		// an empty block emits nothing and opens no scope
		if len(node.Stmts) == 0 {
			return nil
		}

		c.symbolTable = c.symbolTable.Fork(true)
		defer func() {
			c.symbolTable = c.symbolTable.Parent(false)
		}()

		for _, stmt := range node.Stmts {
			if err := c.Compile(stmt); err != nil {
				return err
			}
		}
	case *parser.AssignStmt:
		err := c.compileAssign(node, node.LHS, node.RHS, node.Token)
		if err != nil {
			return err
		}
	case *parser.Ident:
		symbol, _, ok := c.symbolTable.Resolve(node.Name)
		if !ok {
			return c.errorf(node, "unresolved reference '%s'", node.Name)
		}

		// the load opcode depends on where the symbol lives
		switch symbol.Scope {
		case ScopeGlobal:
			c.emit(node, parser.OpGetGlobal, symbol.Index)
		case ScopeLocal:
			c.emit(node, parser.OpGetLocal, symbol.Index)
		case ScopeBuiltin:
			c.emit(node, parser.OpGetBuiltin, symbol.Index)
		case ScopeFree:
			c.emit(node, parser.OpGetFree, symbol.Index)
		}
	case *parser.ArrayLit:
		for _, elem := range node.Elements {
			if err := c.Compile(elem); err != nil {
				return err
			}
		}
		c.emit(node, parser.OpArray, len(node.Elements))
	case *parser.MapLit:
		for _, elt := range node.Elements {
			// key
			if len(elt.Key) > MaxStringLen {
				return c.error(node, ErrStringLimit)
			}
			c.emit(node, parser.OpConstant,
				c.addConstant(&String{Value: elt.Key}))

			// value
			if err := c.Compile(elt.Value); err != nil {
				return err
			}
		}
		// the operand counts keys and values together
		c.emit(node, parser.OpMap, len(node.Elements)*2)
	case *parser.SelectorExpr: // selector on RHS side
		if err := c.Compile(node.Expr); err != nil {
			return err
		}
		if err := c.Compile(node.Sel); err != nil {
			return err
		}
		c.emit(node, parser.OpIndex)
	case *parser.IndexExpr:
		if err := c.Compile(node.Expr); err != nil {
			return err
		}
		if err := c.Compile(node.Index); err != nil {
			return err
		}
		c.emit(node, parser.OpIndex)
	case *parser.SliceExpr:
		if err := c.Compile(node.Expr); err != nil {
			return err
		}

		// a missing low/high bound is pushed as undefined (OpNull)
		if node.Low != nil {
			if err := c.Compile(node.Low); err != nil {
				return err
			}
		} else {
			c.emit(node, parser.OpNull)
		}
		if node.High != nil {
			if err := c.Compile(node.High); err != nil {
				return err
			}
		} else {
			c.emit(node, parser.OpNull)
		}
		c.emit(node, parser.OpSliceIndex)
	case *parser.FuncLit:
		// the function body is compiled in its own compilation scope
		c.enterScope()

		for _, p := range node.Type.Params.List {
			s := c.symbolTable.Define(p.Name)

			// function arguments is not assigned directly.
			s.LocalAssigned = true
		}

		if err := c.Compile(node.Body); err != nil {
			return err
		}

		// code optimization
		c.optimizeFunc(node)

		freeSymbols := c.symbolTable.FreeSymbols()
		numLocals := c.symbolTable.MaxSymbols()
		instructions, sourceMap := c.leaveScope()

		// push a pointer for every captured variable; these become the
		// closure's free variables
		for _, s := range freeSymbols {
			switch s.Scope {
			case ScopeLocal:
				if !s.LocalAssigned {
					// Here, the closure is capturing a local variable that's
					// not yet assigned its value. One example is a local
					// recursive function:
					//
					//   func() {
					//     foo := func(x) {
					//       // ..
					//       return foo(x-1)
					//     }
					//   }
					//
					// which translate into
					//
					//   0000 GETL    0
					//   0002 CLOSURE ?     1
					//   0006 DEFL    0
					//
					// . So the local variable (0) is being captured before
					// it's assigned the value.
					//
					// Solution is to transform the code into something like
					// this:
					//
					//   func() {
					//     foo := undefined
					//     foo = func(x) {
					//       // ..
					//       return foo(x-1)
					//     }
					//   }
					//
					// that is equivalent to
					//
					//   0000 NULL
					//   0001 DEFL    0
					//   0003 GETL    0
					//   0005 CLOSURE ?     1
					//   0009 SETL    0
					//
					c.emit(node, parser.OpNull)
					c.emit(node, parser.OpDefineLocal, s.Index)
					s.LocalAssigned = true
				}
				c.emit(node, parser.OpGetLocalPtr, s.Index)
			case ScopeFree:
				c.emit(node, parser.OpGetFreePtr, s.Index)
			}
		}

		compiledFunction := &CompiledFunction{
			Instructions:  instructions,
			NumLocals:     numLocals,
			NumParameters: len(node.Type.Params.List),
			VarArgs:       node.Type.Params.VarArgs,
			SourceMap:     sourceMap,
		}
		// a function with captures becomes a closure; otherwise it is a plain
		// constant
		if len(freeSymbols) > 0 {
			c.emit(node, parser.OpClosure,
				c.addConstant(compiledFunction), len(freeSymbols))
		} else {
			c.emit(node, parser.OpConstant, c.addConstant(compiledFunction))
		}
	case *parser.ReturnStmt:
		if c.symbolTable.Parent(true) == nil {
			// outside the function
			return c.errorf(node, "return not allowed outside function")
		}

		// the operand tells whether a return value was pushed (1) or not (0)
		if node.Result == nil {
			c.emit(node, parser.OpReturn, 0)
		} else {
			if err := c.Compile(node.Result); err != nil {
				return err
			}
			c.emit(node, parser.OpReturn, 1)
		}
	case *parser.CallExpr:
		if err := c.Compile(node.Func); err != nil {
			return err
		}
		for _, arg := range node.Args {
			if err := c.Compile(arg); err != nil {
				return err
			}
		}
		c.emit(node, parser.OpCall, len(node.Args))
	case *parser.ImportExpr:
		if node.ModuleName == "" {
			return c.errorf(node, "empty module name")
		}

		// modules registered in c.modules are checked before local files
		if mod := c.modules.Get(node.ModuleName); mod != nil {
			v, err := mod.Import(node.ModuleName)
			if err != nil {
				return err
			}

			switch v := v.(type) {
			case []byte: // module written in Tengo
				// compiled into a function that is called immediately; its
				// result is the module value
				compiled, err := c.compileModule(node,
					node.ModuleName, node.ModuleName, v)
				if err != nil {
					return err
				}
				c.emit(node, parser.OpConstant, c.addConstant(compiled))
				c.emit(node, parser.OpCall, 0)
			case Object: // builtin module
				c.emit(node, parser.OpConstant, c.addConstant(v))
			default:
				panic(fmt.Errorf("invalid import value type: %T", v))
			}
		} else if c.allowFileImport {
			// the ".tengo" extension is optional in the import name
			moduleName := node.ModuleName
			if !strings.HasSuffix(moduleName, ".tengo") {
				moduleName += ".tengo"
			}

			// the absolute path identifies the module for cycle detection
			// and caching
			modulePath, err := filepath.Abs(moduleName)
			if err != nil {
				return c.errorf(node, "module file path error: %s",
					err.Error())
			}

			if err := c.checkCyclicImports(node, modulePath); err != nil {
				return err
			}

			moduleSrc, err := ioutil.ReadFile(moduleName)
			if err != nil {
				return c.errorf(node, "module file read error: %s",
					err.Error())
			}

			compiled, err := c.compileModule(node,
				moduleName, modulePath, moduleSrc)
			if err != nil {
				return err
			}
			c.emit(node, parser.OpConstant, c.addConstant(compiled))
			c.emit(node, parser.OpCall, 0)
		} else {
			return c.errorf(node, "module '%s' not found", node.ModuleName)
		}
	case *parser.ExportStmt:
		// export statement must be in top-level scope
		if c.scopeIndex != 0 {
			return c.errorf(node, "export not allowed inside function")
		}

		// export statement is simply ignored when compiling non-module code
		if c.parent == nil {
			break
		}
		if err := c.Compile(node.Result); err != nil {
			return err
		}
		// the exported value is made immutable and returned from the
		// module function
		c.emit(node, parser.OpImmutable)
		c.emit(node, parser.OpReturn, 1)
	case *parser.ErrorExpr:
		if err := c.Compile(node.Expr); err != nil {
			return err
		}
		c.emit(node, parser.OpError)
	case *parser.ImmutableExpr:
		if err := c.Compile(node.Expr); err != nil {
			return err
		}
		c.emit(node, parser.OpImmutable)
	case *parser.CondExpr:
		if err := c.Compile(node.Cond); err != nil {
			return err
		}

		// first jump placeholder (to the False branch)
		jumpPos1 := c.emit(node, parser.OpJumpFalsy, 0)
		if err := c.Compile(node.True); err != nil {
			return err
		}

		// second jump placeholder (True branch skips the False branch)
		jumpPos2 := c.emit(node, parser.OpJump, 0)

		// update first jump offset
		curPos := len(c.currentInstructions())
		c.changeOperand(jumpPos1, curPos)
		if err := c.Compile(node.False); err != nil {
			return err
		}

		// update second jump offset
		curPos = len(c.currentInstructions())
		c.changeOperand(jumpPos2, curPos)
	}
	return nil
}

// Bytecode returns a compiled bytecode.
//
// The main function consists of the current scope's instructions followed by
// a trailing OpSuspend, plus the current scope's source map and the
// constants collected so far.
func (c *Compiler) Bytecode() *Bytecode {
	return &Bytecode{
		FileSet: c.file.Set(),
		MainFunction: &CompiledFunction{
			Instructions: append(c.currentInstructions(), parser.OpSuspend),
			SourceMap:    c.currentSourceMap(),
		},
		Constants: c.constants,
	}
}

// EnableFileImport enables or disables module loading from local files.
// Local file modules are disabled by default.
func (c *Compiler) EnableFileImport(enable bool) { c.allowFileImport = enable } func (c *Compiler) compileAssign( node parser.Node, lhs, rhs []parser.Expr, op token.Token, ) error { numLHS, numRHS := len(lhs), len(rhs) if numLHS > 1 || numRHS > 1 { return c.errorf(node, "tuple assignment not allowed") } // resolve and compile left-hand side ident, selectors := resolveAssignLHS(lhs[0]) numSel := len(selectors) if op == token.Define && numSel > 0 { // using selector on new variable does not make sense return c.errorf(node, "operator ':=' not allowed with selector") } symbol, depth, exists := c.symbolTable.Resolve(ident) if op == token.Define { if depth == 0 && exists { return c.errorf(node, "'%s' redeclared in this block", ident) } symbol = c.symbolTable.Define(ident) } else { if !exists { return c.errorf(node, "unresolved reference '%s'", ident) } } // +=, -=, *=, /= if op != token.Assign && op != token.Define { if err := c.Compile(lhs[0]); err != nil { return err } } // compile RHSs for _, expr := range rhs { if err := c.Compile(expr); err != nil { return err } } switch op { case token.AddAssign: c.emit(node, parser.OpBinaryOp, int(token.Add)) case token.SubAssign: c.emit(node, parser.OpBinaryOp, int(token.Sub)) case token.MulAssign: c.emit(node, parser.OpBinaryOp, int(token.Mul)) case token.QuoAssign: c.emit(node, parser.OpBinaryOp, int(token.Quo)) case token.RemAssign: c.emit(node, parser.OpBinaryOp, int(token.Rem)) case token.AndAssign: c.emit(node, parser.OpBinaryOp, int(token.And)) case token.OrAssign: c.emit(node, parser.OpBinaryOp, int(token.Or)) case token.AndNotAssign: c.emit(node, parser.OpBinaryOp, int(token.AndNot)) case token.XorAssign: c.emit(node, parser.OpBinaryOp, int(token.Xor)) case token.ShlAssign: c.emit(node, parser.OpBinaryOp, int(token.Shl)) case token.ShrAssign: c.emit(node, parser.OpBinaryOp, int(token.Shr)) } // compile selector expressions (right to left) for i := numSel - 1; i >= 0; i-- { if err := c.Compile(selectors[i]); err != nil { 
return err } } switch symbol.Scope { case ScopeGlobal: if numSel > 0 { c.emit(node, parser.OpSetSelGlobal, symbol.Index, numSel) } else { c.emit(node, parser.OpSetGlobal, symbol.Index) } case ScopeLocal: if numSel > 0 { c.emit(node, parser.OpSetSelLocal, symbol.Index, numSel) } else { if op == token.Define && !symbol.LocalAssigned { c.emit(node, parser.OpDefineLocal, symbol.Index) } else { c.emit(node, parser.OpSetLocal, symbol.Index) } } // mark the symbol as local-assigned symbol.LocalAssigned = true case ScopeFree: if numSel > 0 { c.emit(node, parser.OpSetSelFree, symbol.Index, numSel) } else { c.emit(node, parser.OpSetFree, symbol.Index) } default: panic(fmt.Errorf("invalid assignment variable scope: %s", symbol.Scope)) } return nil } func (c *Compiler) compileLogical(node *parser.BinaryExpr) error { // left side term if err := c.Compile(node.LHS); err != nil { return err } // jump position var jumpPos int if node.Token == token.LAnd { jumpPos = c.emit(node, parser.OpAndJump, 0) } else { jumpPos = c.emit(node, parser.OpOrJump, 0) } // right side term if err := c.Compile(node.RHS); err != nil { return err } c.changeOperand(jumpPos, len(c.currentInstructions())) return nil } func (c *Compiler) compileForStmt(stmt *parser.ForStmt) error { c.symbolTable = c.symbolTable.Fork(true) defer func() { c.symbolTable = c.symbolTable.Parent(false) }() // init statement if stmt.Init != nil { if err := c.Compile(stmt.Init); err != nil { return err } } // pre-condition position preCondPos := len(c.currentInstructions()) // condition expression postCondPos := -1 if stmt.Cond != nil { if err := c.Compile(stmt.Cond); err != nil { return err } // condition jump position postCondPos = c.emit(stmt, parser.OpJumpFalsy, 0) } // enter loop loop := c.enterLoop() // body statement if err := c.Compile(stmt.Body); err != nil { c.leaveLoop() return err } c.leaveLoop() // post-body position postBodyPos := len(c.currentInstructions()) // post statement if stmt.Post != nil { if err := 
c.Compile(stmt.Post); err != nil { return err } } // back to condition c.emit(stmt, parser.OpJump, preCondPos) // post-statement position postStmtPos := len(c.currentInstructions()) if postCondPos >= 0 { c.changeOperand(postCondPos, postStmtPos) } // update all break/continue jump positions for _, pos := range loop.Breaks { c.changeOperand(pos, postStmtPos) } for _, pos := range loop.Continues { c.changeOperand(pos, postBodyPos) } return nil } func (c *Compiler) compileForInStmt(stmt *parser.ForInStmt) error { c.symbolTable = c.symbolTable.Fork(true) defer func() { c.symbolTable = c.symbolTable.Parent(false) }() // for-in statement is compiled like following: // // for :it := iterator(iterable); :it.next(); { // k, v := :it.get() // DEFINE operator // // ... body ... // } // // ":it" is a local variable but will be conflict with other user variables // because character ":" is not allowed. // init // :it = iterator(iterable) itSymbol := c.symbolTable.Define(":it") if err := c.Compile(stmt.Iterable); err != nil { return err } c.emit(stmt, parser.OpIteratorInit) if itSymbol.Scope == ScopeGlobal { c.emit(stmt, parser.OpSetGlobal, itSymbol.Index) } else { c.emit(stmt, parser.OpDefineLocal, itSymbol.Index) } // pre-condition position preCondPos := len(c.currentInstructions()) // condition // :it.HasMore() if itSymbol.Scope == ScopeGlobal { c.emit(stmt, parser.OpGetGlobal, itSymbol.Index) } else { c.emit(stmt, parser.OpGetLocal, itSymbol.Index) } c.emit(stmt, parser.OpIteratorNext) // condition jump position postCondPos := c.emit(stmt, parser.OpJumpFalsy, 0) // enter loop loop := c.enterLoop() // assign key variable if stmt.Key.Name != "_" { keySymbol := c.symbolTable.Define(stmt.Key.Name) if itSymbol.Scope == ScopeGlobal { c.emit(stmt, parser.OpGetGlobal, itSymbol.Index) } else { c.emit(stmt, parser.OpGetLocal, itSymbol.Index) } c.emit(stmt, parser.OpIteratorKey) if keySymbol.Scope == ScopeGlobal { c.emit(stmt, parser.OpSetGlobal, keySymbol.Index) } else { c.emit(stmt, 
parser.OpDefineLocal, keySymbol.Index) } } // assign value variable if stmt.Value.Name != "_" { valueSymbol := c.symbolTable.Define(stmt.Value.Name) if itSymbol.Scope == ScopeGlobal { c.emit(stmt, parser.OpGetGlobal, itSymbol.Index) } else { c.emit(stmt, parser.OpGetLocal, itSymbol.Index) } c.emit(stmt, parser.OpIteratorValue) if valueSymbol.Scope == ScopeGlobal { c.emit(stmt, parser.OpSetGlobal, valueSymbol.Index) } else { c.emit(stmt, parser.OpDefineLocal, valueSymbol.Index) } } // body statement if err := c.Compile(stmt.Body); err != nil { c.leaveLoop() return err } c.leaveLoop() // post-body position postBodyPos := len(c.currentInstructions()) // back to condition c.emit(stmt, parser.OpJump, preCondPos) // post-statement position postStmtPos := len(c.currentInstructions()) c.changeOperand(postCondPos, postStmtPos) // update all break/continue jump positions for _, pos := range loop.Breaks { c.changeOperand(pos, postStmtPos) } for _, pos := range loop.Continues { c.changeOperand(pos, postBodyPos) } return nil } func (c *Compiler) checkCyclicImports( node parser.Node, modulePath string, ) error { if c.modulePath == modulePath { return c.errorf(node, "cyclic module import: %s", modulePath) } else if c.parent != nil { return c.parent.checkCyclicImports(node, modulePath) } return nil } func (c *Compiler) compileModule( node parser.Node, moduleName, modulePath string, src []byte, ) (*CompiledFunction, error) { if err := c.checkCyclicImports(node, modulePath); err != nil { return nil, err } compiledModule, exists := c.loadCompiledModule(modulePath) if exists { return compiledModule, nil } modFile := c.file.Set().AddFile(moduleName, -1, len(src)) p := parser.NewParser(modFile, src, nil) file, err := p.ParseFile() if err != nil { return nil, err } // inherit builtin functions symbolTable := NewSymbolTable() for _, sym := range c.symbolTable.BuiltinSymbols() { symbolTable.DefineBuiltin(sym.Index, sym.Name) } // no global scope for the module symbolTable = 
symbolTable.Fork(false) // compile module moduleCompiler := c.fork(modFile, modulePath, symbolTable) if err := moduleCompiler.Compile(file); err != nil { return nil, err } // code optimization moduleCompiler.optimizeFunc(node) compiledFunc := moduleCompiler.Bytecode().MainFunction compiledFunc.NumLocals = symbolTable.MaxSymbols() c.storeCompiledModule(modulePath, compiledFunc) return compiledFunc, nil } func (c *Compiler) loadCompiledModule( modulePath string, ) (mod *CompiledFunction, ok bool) { if c.parent != nil { return c.parent.loadCompiledModule(modulePath) } mod, ok = c.compiledModules[modulePath] return } func (c *Compiler) storeCompiledModule( modulePath string, module *CompiledFunction, ) { if c.parent != nil { c.parent.storeCompiledModule(modulePath, module) } c.compiledModules[modulePath] = module } func (c *Compiler) enterLoop() *loop { loop := &loop{} c.loops = append(c.loops, loop) c.loopIndex++ if c.trace != nil { c.printTrace("LOOPE", c.loopIndex) } return loop } func (c *Compiler) leaveLoop() { if c.trace != nil { c.printTrace("LOOPL", c.loopIndex) } c.loops = c.loops[:len(c.loops)-1] c.loopIndex-- } func (c *Compiler) currentLoop() *loop { if c.loopIndex >= 0 { return c.loops[c.loopIndex] } return nil } func (c *Compiler) currentInstructions() []byte { return c.scopes[c.scopeIndex].Instructions } func (c *Compiler) currentSourceMap() map[int]parser.Pos { return c.scopes[c.scopeIndex].SourceMap } func (c *Compiler) enterScope() { scope := compilationScope{ SymbolInit: make(map[string]bool), SourceMap: make(map[int]parser.Pos), } c.scopes = append(c.scopes, scope) c.scopeIndex++ c.symbolTable = c.symbolTable.Fork(false) if c.trace != nil { c.printTrace("SCOPE", c.scopeIndex) } } func (c *Compiler) leaveScope() ( instructions []byte, sourceMap map[int]parser.Pos, ) { instructions = c.currentInstructions() sourceMap = c.currentSourceMap() c.scopes = c.scopes[:len(c.scopes)-1] c.scopeIndex-- c.symbolTable = c.symbolTable.Parent(true) if c.trace != nil 
{ c.printTrace("SCOPL", c.scopeIndex) } return } func (c *Compiler) fork( file *parser.SourceFile, modulePath string, symbolTable *SymbolTable, ) *Compiler { child := NewCompiler(file, symbolTable, nil, c.modules, c.trace) child.modulePath = modulePath // module file path child.parent = c // parent to set to current compiler return child } func (c *Compiler) error(node parser.Node, err error) error { return &CompilerError{ FileSet: c.file.Set(), Node: node, Err: err, } } func (c *Compiler) errorf( node parser.Node, format string, args ...interface{}, ) error { return &CompilerError{ FileSet: c.file.Set(), Node: node, Err: fmt.Errorf(format, args...), } } func (c *Compiler) addConstant(o Object) int { if c.parent != nil { // module compilers will use their parent's constants array return c.parent.addConstant(o) } c.constants = append(c.constants, o) if c.trace != nil { c.printTrace(fmt.Sprintf("CONST %04d %s", len(c.constants)-1, o)) } return len(c.constants) - 1 } func (c *Compiler) addInstruction(b []byte) int { posNewIns := len(c.currentInstructions()) c.scopes[c.scopeIndex].Instructions = append( c.currentInstructions(), b...) return posNewIns } func (c *Compiler) replaceInstruction(pos int, inst []byte) { copy(c.currentInstructions()[pos:], inst) if c.trace != nil { c.printTrace(fmt.Sprintf("REPLC %s", FormatInstructions( c.scopes[c.scopeIndex].Instructions[pos:], pos)[0])) } } func (c *Compiler) changeOperand(opPos int, operand ...int) { op := c.currentInstructions()[opPos] inst := MakeInstruction(op, operand...) c.replaceInstruction(opPos, inst) } // optimizeFunc performs some code-level optimization for the current function // instructions. It also removes unreachable (dead code) instructions and adds // "returns" instruction if needed. func (c *Compiler) optimizeFunc(node parser.Node) { // any instructions between RETURN and the function end // or instructions between RETURN and jump target position // are considered as unreachable. // pass 1. 
identify all jump destinations dsts := make(map[int]bool) iterateInstructions(c.scopes[c.scopeIndex].Instructions, func(pos int, opcode parser.Opcode, operands []int) bool { switch opcode { case parser.OpJump, parser.OpJumpFalsy, parser.OpAndJump, parser.OpOrJump: dsts[operands[0]] = true } return true }) // pass 2. eliminate dead code var newInsts []byte posMap := make(map[int]int) // old position to new position var dstIdx int var deadCode bool iterateInstructions(c.scopes[c.scopeIndex].Instructions, func(pos int, opcode parser.Opcode, operands []int) bool { switch { case opcode == parser.OpReturn: if deadCode { return true } deadCode = true case dsts[pos]: dstIdx++ deadCode = false case deadCode: return true } posMap[pos] = len(newInsts) newInsts = append(newInsts, MakeInstruction(opcode, operands...)...) return true }) // pass 3. update jump positions var lastOp parser.Opcode var appendReturn bool endPos := len(c.scopes[c.scopeIndex].Instructions) iterateInstructions(newInsts, func(pos int, opcode parser.Opcode, operands []int) bool { switch opcode { case parser.OpJump, parser.OpJumpFalsy, parser.OpAndJump, parser.OpOrJump: newDst, ok := posMap[operands[0]] if ok { copy(newInsts[pos:], MakeInstruction(opcode, newDst)) } else if endPos == operands[0] { // there's a jump instruction that jumps to the end of // function compiler should append "return". appendReturn = true } else { panic(fmt.Errorf("invalid jump position: %d", newDst)) } } lastOp = opcode return true }) if lastOp != parser.OpReturn { appendReturn = true } // pass 4. 
update source map newSourceMap := make(map[int]parser.Pos) for pos, srcPos := range c.scopes[c.scopeIndex].SourceMap { newPos, ok := posMap[pos] if ok { newSourceMap[newPos] = srcPos } } c.scopes[c.scopeIndex].Instructions = newInsts c.scopes[c.scopeIndex].SourceMap = newSourceMap // append "return" if appendReturn { c.emit(node, parser.OpReturn, 0) } } func (c *Compiler) emit( node parser.Node, opcode parser.Opcode, operands ...int, ) int { filePos := parser.NoPos if node != nil { filePos = node.Pos() } inst := MakeInstruction(opcode, operands...) pos := c.addInstruction(inst) c.scopes[c.scopeIndex].SourceMap[pos] = filePos if c.trace != nil { c.printTrace(fmt.Sprintf("EMIT %s", FormatInstructions( c.scopes[c.scopeIndex].Instructions[pos:], pos)[0])) } return pos } func (c *Compiler) printTrace(a ...interface{}) { const ( dots = ". . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . " n = len(dots) ) i := 2 * c.indent for i > n { _, _ = fmt.Fprint(c.trace, dots) i -= n } _, _ = fmt.Fprint(c.trace, dots[0:i]) _, _ = fmt.Fprintln(c.trace, a...) } func resolveAssignLHS( expr parser.Expr, ) (name string, selectors []parser.Expr) { switch term := expr.(type) { case *parser.SelectorExpr: name, selectors = resolveAssignLHS(term.Expr) selectors = append(selectors, term.Sel) return case *parser.IndexExpr: name, selectors = resolveAssignLHS(term.Expr) selectors = append(selectors, term.Index) case *parser.Ident: name = term.Name } return } func iterateInstructions( b []byte, fn func(pos int, opcode parser.Opcode, operands []int) bool, ) { for i := 0; i < len(b); i++ { numOperands := parser.OpcodeOperands[b[i]] operands, read := parser.ReadOperands(numOperands, b[i+1:]) if !fn(i, b[i], operands) { break } i += read } } func tracec(c *Compiler, msg string) *Compiler { c.printTrace(msg, "{") c.indent++ return c } func untracec(c *Compiler) { c.indent-- c.printTrace("}") }