diff options
Diffstat (limited to 'vendor/github.com/d5/tengo/compiler/compiler.go')
-rw-r--r-- | vendor/github.com/d5/tengo/compiler/compiler.go | 268 |
1 files changed, 183 insertions, 85 deletions
diff --git a/vendor/github.com/d5/tengo/compiler/compiler.go b/vendor/github.com/d5/tengo/compiler/compiler.go index d8bc05fd..4a3ec3ad 100644 --- a/vendor/github.com/d5/tengo/compiler/compiler.go +++ b/vendor/github.com/d5/tengo/compiler/compiler.go @@ -3,7 +3,10 @@ package compiler import ( "fmt" "io" + "io/ioutil" + "path/filepath" "reflect" + "strings" "github.com/d5/tengo" "github.com/d5/tengo/compiler/ast" @@ -16,14 +19,14 @@ import ( type Compiler struct { file *source.File parent *Compiler - moduleName string + modulePath string constants []objects.Object symbolTable *SymbolTable scopes []CompilationScope scopeIndex int - moduleLoader ModuleLoader - builtinModules map[string]bool + modules *objects.ModuleMap compiledModules map[string]*objects.CompiledFunction + allowFileImport bool loops []*Loop loopIndex int trace io.Writer @@ -31,12 +34,7 @@ type Compiler struct { } // NewCompiler creates a Compiler. -// User can optionally provide the symbol table if one wants to add or remove -// some global- or builtin- scope symbols. If not (nil), Compile will create -// a new symbol table and use the default builtin functions. Likewise, standard -// modules can be explicitly provided if user wants to add or remove some modules. -// By default, Compile will use all the standard modules otherwise. -func NewCompiler(file *source.File, symbolTable *SymbolTable, constants []objects.Object, builtinModules map[string]bool, trace io.Writer) *Compiler { +func NewCompiler(file *source.File, symbolTable *SymbolTable, constants []objects.Object, modules *objects.ModuleMap, trace io.Writer) *Compiler { mainScope := CompilationScope{ symbolInit: make(map[string]bool), sourceMap: make(map[int]source.Pos), @@ -45,15 +43,16 @@ func NewCompiler(file *source.File, symbolTable *SymbolTable, constants []object // symbol table if symbolTable == nil { symbolTable = NewSymbolTable() + } - for idx, fn := range objects.Builtins { - symbolTable.DefineBuiltin(idx, fn.Name) - } + // add builtin functions to the symbol table + for idx, fn := range objects.Builtins { + symbolTable.DefineBuiltin(idx, fn.Name) } // builtin modules - if builtinModules == nil { - builtinModules = make(map[string]bool) + if modules == nil { + modules = objects.NewModuleMap() } return &Compiler{ @@ -64,7 +63,7 @@ func NewCompiler(file *source.File, symbolTable *SymbolTable, constants []object scopeIndex: 0, loopIndex: -1, trace: trace, - builtinModules: builtinModules, + modules: modules, compiledModules: make(map[string]*objects.CompiledFunction), } } @@ -120,7 +119,7 @@ func (c *Compiler) Compile(node ast.Node) error { return err } - c.emit(node, OpGreaterThan) + c.emit(node, OpBinaryOp, int(token.Greater)) return nil } else if node.Token == token.LessEq { @@ -131,7 +130,7 @@ func (c *Compiler) Compile(node ast.Node) error { return err } - c.emit(node, OpGreaterThanEqual) + c.emit(node, OpBinaryOp, int(token.GreaterEq)) return nil } @@ -145,35 +144,35 @@ func (c *Compiler) Compile(node ast.Node) error { switch node.Token { case token.Add: - c.emit(node, OpAdd) + c.emit(node, OpBinaryOp, int(token.Add)) case token.Sub: - c.emit(node, OpSub) + c.emit(node, OpBinaryOp, int(token.Sub)) case token.Mul: - c.emit(node, OpMul) + c.emit(node, OpBinaryOp, int(token.Mul)) case token.Quo: - c.emit(node, OpDiv) + c.emit(node, OpBinaryOp, int(token.Quo)) case token.Rem: - c.emit(node, OpRem) + c.emit(node, OpBinaryOp, int(token.Rem)) case token.Greater: - c.emit(node, OpGreaterThan) + c.emit(node, OpBinaryOp, int(token.Greater)) case token.GreaterEq: - c.emit(node, OpGreaterThanEqual) + c.emit(node, OpBinaryOp, int(token.GreaterEq)) case token.Equal: c.emit(node, OpEqual) case token.NotEqual: c.emit(node, OpNotEqual) case token.And: - c.emit(node, OpBAnd) + c.emit(node, OpBinaryOp, int(token.And)) case token.Or: - c.emit(node, OpBOr) + c.emit(node, OpBinaryOp, int(token.Or)) case token.Xor: - c.emit(node, OpBXor) + c.emit(node, OpBinaryOp, int(token.Xor)) case token.AndNot: - c.emit(node, OpBAndNot) + c.emit(node, OpBinaryOp, int(token.AndNot)) case token.Shl: - c.emit(node, OpBShiftLeft) + c.emit(node, OpBinaryOp, int(token.Shl)) case token.Shr: - c.emit(node, OpBShiftRight) + c.emit(node, OpBinaryOp, int(token.Shr)) default: return c.errorf(node, "invalid binary operator: %s", node.Token.String()) } @@ -293,6 +292,15 @@ func (c *Compiler) Compile(node ast.Node) error { } case *ast.BlockStmt: + if len(node.Stmts) == 0 { + return nil + } + + c.symbolTable = c.symbolTable.Fork(true) + defer func() { + c.symbolTable = c.symbolTable.Parent(false) + }() + for _, stmt := range node.Stmts { if err := c.Compile(stmt); err != nil { return err @@ -405,10 +413,8 @@ func (c *Compiler) Compile(node ast.Node) error { return err } - // add OpReturn if function returns nothing - if !c.lastInstructionIs(OpReturnValue) && !c.lastInstructionIs(OpReturn) { - c.emit(node, OpReturn) - } + // code optimization + c.optimizeFunc(node) freeSymbols := c.symbolTable.FreeSymbols() numLocals := c.symbolTable.MaxSymbols() @@ -461,9 +467,9 @@ func (c *Compiler) Compile(node ast.Node) error { s.LocalAssigned = true } - c.emit(node, OpGetLocal, s.Index) + c.emit(node, OpGetLocalPtr, s.Index) case ScopeFree: - c.emit(node, OpGetFree, s.Index) + c.emit(node, OpGetFreePtr, s.Index) } } @@ -487,13 +493,13 @@ func (c *Compiler) Compile(node ast.Node) error { } if node.Result == nil { - c.emit(node, OpReturn) + c.emit(node, OpReturn, 0) } else { if err := c.Compile(node.Result); err != nil { return err } - c.emit(node, OpReturnValue) + c.emit(node, OpReturn, 1) } case *ast.CallExpr: @@ -510,21 +516,57 @@ func (c *Compiler) Compile(node ast.Node) error { c.emit(node, OpCall, len(node.Args)) case *ast.ImportExpr: - if c.builtinModules[node.ModuleName] { - if len(node.ModuleName) > tengo.MaxStringLen { - return c.error(node, objects.ErrStringLimit) + if node.ModuleName == "" { + return c.errorf(node, "empty module name") + } + + if mod := c.modules.Get(node.ModuleName); mod != nil { + v, err := mod.Import(node.ModuleName) + if err != nil { + return err } - c.emit(node, OpConstant, c.addConstant(&objects.String{Value: node.ModuleName})) - c.emit(node, OpGetBuiltinModule) - } else { - userMod, err := c.compileModule(node) + switch v := v.(type) { + case []byte: // module written in Tengo + compiled, err := c.compileModule(node, node.ModuleName, node.ModuleName, v) + if err != nil { + return err + } + c.emit(node, OpConstant, c.addConstant(compiled)) + c.emit(node, OpCall, 0) + case objects.Object: // builtin module + c.emit(node, OpConstant, c.addConstant(v)) + default: + panic(fmt.Errorf("invalid import value type: %T", v)) + } + } else if c.allowFileImport { + moduleName := node.ModuleName + if !strings.HasSuffix(moduleName, ".tengo") { + moduleName += ".tengo" + } + + modulePath, err := filepath.Abs(moduleName) if err != nil { + return c.errorf(node, "module file path error: %s", err.Error()) + } + + if err := c.checkCyclicImports(node, modulePath); err != nil { return err } - c.emit(node, OpConstant, c.addConstant(userMod)) + moduleSrc, err := ioutil.ReadFile(moduleName) + if err != nil { + return c.errorf(node, "module file read error: %s", err.Error()) + } + + compiled, err := c.compileModule(node, moduleName, modulePath, moduleSrc) + if err != nil { + return err + } + c.emit(node, OpConstant, c.addConstant(compiled)) c.emit(node, OpCall, 0) + } else { + return c.errorf(node, "module '%s' not found", node.ModuleName) } case *ast.ExportStmt: @@ -543,7 +585,7 @@ func (c *Compiler) Compile(node ast.Node) error { } c.emit(node, OpImmutable) - c.emit(node, OpReturnValue) + c.emit(node, OpReturn, 1) case *ast.ErrorExpr: if err := c.Compile(node.Expr); err != nil { @@ -602,18 +644,16 @@ func (c *Compiler) Bytecode() *Bytecode { } } -// SetModuleLoader sets or replaces the current module loader. -// Note that the module loader is used for user modules, -// not for the standard modules. -func (c *Compiler) SetModuleLoader(moduleLoader ModuleLoader) { - c.moduleLoader = moduleLoader +// EnableFileImport enables or disables module loading from local files. +// Local file modules are disabled by default. +func (c *Compiler) EnableFileImport(enable bool) { + c.allowFileImport = enable } -func (c *Compiler) fork(file *source.File, moduleName string, symbolTable *SymbolTable) *Compiler { - child := NewCompiler(file, symbolTable, nil, c.builtinModules, c.trace) - child.moduleName = moduleName // name of the module to compile - child.parent = c // parent to set to current compiler - child.moduleLoader = c.moduleLoader // share module loader +func (c *Compiler) fork(file *source.File, modulePath string, symbolTable *SymbolTable) *Compiler { + child := NewCompiler(file, symbolTable, nil, c.modules, c.trace) + child.modulePath = modulePath // module file path + child.parent = c // parent to set to current compiler return child } @@ -657,33 +697,6 @@ func (c *Compiler) addInstruction(b []byte) int { return posNewIns } -func (c *Compiler) setLastInstruction(op Opcode, pos int) { - c.scopes[c.scopeIndex].lastInstructions[1] = c.scopes[c.scopeIndex].lastInstructions[0] - - c.scopes[c.scopeIndex].lastInstructions[0].Opcode = op - c.scopes[c.scopeIndex].lastInstructions[0].Position = pos -} - -func (c *Compiler) lastInstructionIs(op Opcode) bool { - if len(c.currentInstructions()) == 0 { - return false - } - - return c.scopes[c.scopeIndex].lastInstructions[0].Opcode == op -} - -func (c *Compiler) removeLastInstruction() { - lastPos := c.scopes[c.scopeIndex].lastInstructions[0].Position - - if c.trace != nil { - c.printTrace(fmt.Sprintf("DELET %s", - FormatInstructions(c.scopes[c.scopeIndex].instructions[lastPos:], lastPos)[0])) - } - - c.scopes[c.scopeIndex].instructions = c.currentInstructions()[:lastPos] - c.scopes[c.scopeIndex].lastInstructions[0] = c.scopes[c.scopeIndex].lastInstructions[1] -} - func (c *Compiler) replaceInstruction(pos int, inst []byte) { copy(c.currentInstructions()[pos:], inst) @@ -700,6 +713,92 @@ func (c *Compiler) changeOperand(opPos int, operand ...int) { c.replaceInstruction(opPos, inst) } +// optimizeFunc performs some code-level optimization for the current function instructions +// it removes unreachable (dead code) instructions and adds "returns" instruction if needed. +func (c *Compiler) optimizeFunc(node ast.Node) { + // any instructions between RETURN and the function end + // or instructions between RETURN and jump target position + // are considered as unreachable. + + // pass 1. identify all jump destinations + dsts := make(map[int]bool) + iterateInstructions(c.scopes[c.scopeIndex].instructions, func(pos int, opcode Opcode, operands []int) bool { + switch opcode { + case OpJump, OpJumpFalsy, OpAndJump, OpOrJump: + dsts[operands[0]] = true + } + + return true + }) + + var newInsts []byte + + // pass 2. eliminate dead code + posMap := make(map[int]int) // old position to new position + var dstIdx int + var deadCode bool + iterateInstructions(c.scopes[c.scopeIndex].instructions, func(pos int, opcode Opcode, operands []int) bool { + switch { + case opcode == OpReturn: + if deadCode { + return true + } + deadCode = true + case dsts[pos]: + dstIdx++ + deadCode = false + case deadCode: + return true + } + + posMap[pos] = len(newInsts) + newInsts = append(newInsts, MakeInstruction(opcode, operands...)...) + return true + }) + + // pass 3. update jump positions + var lastOp Opcode + var appendReturn bool + endPos := len(c.scopes[c.scopeIndex].instructions) + iterateInstructions(newInsts, func(pos int, opcode Opcode, operands []int) bool { + switch opcode { + case OpJump, OpJumpFalsy, OpAndJump, OpOrJump: + newDst, ok := posMap[operands[0]] + if ok { + copy(newInsts[pos:], MakeInstruction(opcode, newDst)) + } else if endPos == operands[0] { + // there's a jump instruction that jumps to the end of function + // compiler should append "return". + appendReturn = true + } else { + panic(fmt.Errorf("invalid jump position: %d", newDst)) + } + } + lastOp = opcode + return true + }) + if lastOp != OpReturn { + appendReturn = true + } + + // pass 4. update source map + newSourceMap := make(map[int]source.Pos) + for pos, srcPos := range c.scopes[c.scopeIndex].sourceMap { + newPos, ok := posMap[pos] + if ok { + newSourceMap[newPos] = srcPos + } + } + + c.scopes[c.scopeIndex].instructions = newInsts + c.scopes[c.scopeIndex].sourceMap = newSourceMap + + // append "return" + if appendReturn { + c.emit(node, OpReturn, 0) + } +} + func (c *Compiler) emit(node ast.Node, opcode Opcode, operands ...int) int { filePos := source.NoPos if node != nil { @@ -709,7 +808,6 @@ func (c *Compiler) emit(node ast.Node, opcode Opcode, operands ...int) int { inst := MakeInstruction(opcode, operands...) pos := c.addInstruction(inst) c.scopes[c.scopeIndex].sourceMap[pos] = filePos - c.setLastInstruction(opcode, pos) if c.trace != nil { c.printTrace(fmt.Sprintf("EMIT %s", |