diff options
Diffstat (limited to 'vendor/github.com/d5/tengo/compiler/compiler.go')
-rw-r--r-- | vendor/github.com/d5/tengo/compiler/compiler.go | 731 |
1 files changed, 731 insertions, 0 deletions
diff --git a/vendor/github.com/d5/tengo/compiler/compiler.go b/vendor/github.com/d5/tengo/compiler/compiler.go new file mode 100644 index 00000000..141ea8fd --- /dev/null +++ b/vendor/github.com/d5/tengo/compiler/compiler.go @@ -0,0 +1,731 @@ +package compiler + +import ( + "fmt" + "io" + "reflect" + + "github.com/d5/tengo/compiler/ast" + "github.com/d5/tengo/compiler/source" + "github.com/d5/tengo/compiler/token" + "github.com/d5/tengo/objects" + "github.com/d5/tengo/stdlib" +) + +// Compiler compiles the AST into a bytecode. +type Compiler struct { + file *source.File + parent *Compiler + moduleName string + constants []objects.Object + symbolTable *SymbolTable + scopes []CompilationScope + scopeIndex int + moduleLoader ModuleLoader + builtinModules map[string]bool + compiledModules map[string]*objects.CompiledFunction + loops []*Loop + loopIndex int + trace io.Writer + indent int +} + +// NewCompiler creates a Compiler. +// User can optionally provide the symbol table if one wants to add or remove +// some global- or builtin- scope symbols. If not (nil), Compile will create +// a new symbol table and use the default builtin functions. Likewise, standard +// modules can be explicitly provided if user wants to add or remove some modules. +// By default, Compile will use all the standard modules otherwise. +func NewCompiler(file *source.File, symbolTable *SymbolTable, constants []objects.Object, builtinModules map[string]bool, trace io.Writer) *Compiler { + mainScope := CompilationScope{ + symbolInit: make(map[string]bool), + sourceMap: make(map[int]source.Pos), + } + + // symbol table + if symbolTable == nil { + symbolTable = NewSymbolTable() + + for idx, fn := range objects.Builtins { + symbolTable.DefineBuiltin(idx, fn.Name) + } + } + + // builtin modules + if builtinModules == nil { + builtinModules = make(map[string]bool) + for name := range stdlib.Modules { + builtinModules[name] = true + } + } + + return &Compiler{ + file: file, + symbolTable: symbolTable, + constants: constants, + scopes: []CompilationScope{mainScope}, + scopeIndex: 0, + loopIndex: -1, + trace: trace, + builtinModules: builtinModules, + compiledModules: make(map[string]*objects.CompiledFunction), + } +} + +// Compile compiles the AST node. +func (c *Compiler) Compile(node ast.Node) error { + if c.trace != nil { + if node != nil { + defer un(trace(c, fmt.Sprintf("%s (%s)", node.String(), reflect.TypeOf(node).Elem().Name()))) + } else { + defer un(trace(c, "<nil>")) + } + } + + switch node := node.(type) { + case *ast.File: + for _, stmt := range node.Stmts { + if err := c.Compile(stmt); err != nil { + return err + } + } + + case *ast.ExprStmt: + if err := c.Compile(node.Expr); err != nil { + return err + } + c.emit(node, OpPop) + + case *ast.IncDecStmt: + op := token.AddAssign + if node.Token == token.Dec { + op = token.SubAssign + } + + return c.compileAssign(node, []ast.Expr{node.Expr}, []ast.Expr{&ast.IntLit{Value: 1}}, op) + + case *ast.ParenExpr: + if err := c.Compile(node.Expr); err != nil { + return err + } + + case *ast.BinaryExpr: + if node.Token == token.LAnd || node.Token == token.LOr { + return c.compileLogical(node) + } + + if node.Token == token.Less { + if err := c.Compile(node.RHS); err != nil { + return err + } + + if err := c.Compile(node.LHS); err != nil { + return err + } + + c.emit(node, OpGreaterThan) + + return nil + } else if node.Token == token.LessEq { + if err := c.Compile(node.RHS); err != nil { + return err + } + if err := c.Compile(node.LHS); err != nil { + return err + } + + c.emit(node, OpGreaterThanEqual) + + return nil + } + + if err := c.Compile(node.LHS); err != nil { + return err + } + if err := c.Compile(node.RHS); err != nil { + return err + } + + switch node.Token { + case token.Add: + c.emit(node, OpAdd) + case token.Sub: + c.emit(node, OpSub) + case token.Mul: + c.emit(node, OpMul) + case token.Quo: + c.emit(node, OpDiv) + case token.Rem: + c.emit(node, OpRem) + case token.Greater: + c.emit(node, OpGreaterThan) + case token.GreaterEq: + c.emit(node, OpGreaterThanEqual) + case token.Equal: + c.emit(node, OpEqual) + case token.NotEqual: + c.emit(node, OpNotEqual) + case token.And: + c.emit(node, OpBAnd) + case token.Or: + c.emit(node, OpBOr) + case token.Xor: + c.emit(node, OpBXor) + case token.AndNot: + c.emit(node, OpBAndNot) + case token.Shl: + c.emit(node, OpBShiftLeft) + case token.Shr: + c.emit(node, OpBShiftRight) + default: + return c.errorf(node, "invalid binary operator: %s", node.Token.String()) + } + + case *ast.IntLit: + c.emit(node, OpConstant, c.addConstant(&objects.Int{Value: node.Value})) + + case *ast.FloatLit: + c.emit(node, OpConstant, c.addConstant(&objects.Float{Value: node.Value})) + + case *ast.BoolLit: + if node.Value { + c.emit(node, OpTrue) + } else { + c.emit(node, OpFalse) + } + + case *ast.StringLit: + c.emit(node, OpConstant, c.addConstant(&objects.String{Value: node.Value})) + + case *ast.CharLit: + c.emit(node, OpConstant, c.addConstant(&objects.Char{Value: node.Value})) + + case *ast.UndefinedLit: + c.emit(node, OpNull) + + case *ast.UnaryExpr: + if err := c.Compile(node.Expr); err != nil { + return err + } + + switch node.Token { + case token.Not: + c.emit(node, OpLNot) + case token.Sub: + c.emit(node, OpMinus) + case token.Xor: + c.emit(node, OpBComplement) + case token.Add: + // do nothing? + default: + return c.errorf(node, "invalid unary operator: %s", node.Token.String()) + } + + case *ast.IfStmt: + // open new symbol table for the statement + c.symbolTable = c.symbolTable.Fork(true) + defer func() { + c.symbolTable = c.symbolTable.Parent(false) + }() + + if node.Init != nil { + if err := c.Compile(node.Init); err != nil { + return err + } + } + + if err := c.Compile(node.Cond); err != nil { + return err + } + + // first jump placeholder + jumpPos1 := c.emit(node, OpJumpFalsy, 0) + + if err := c.Compile(node.Body); err != nil { + return err + } + + if node.Else != nil { + // second jump placeholder + jumpPos2 := c.emit(node, OpJump, 0) + + // update first jump offset + curPos := len(c.currentInstructions()) + c.changeOperand(jumpPos1, curPos) + + if err := c.Compile(node.Else); err != nil { + return err + } + + // update second jump offset + curPos = len(c.currentInstructions()) + c.changeOperand(jumpPos2, curPos) + } else { + // update first jump offset + curPos := len(c.currentInstructions()) + c.changeOperand(jumpPos1, curPos) + } + + case *ast.ForStmt: + return c.compileForStmt(node) + + case *ast.ForInStmt: + return c.compileForInStmt(node) + + case *ast.BranchStmt: + if node.Token == token.Break { + curLoop := c.currentLoop() + if curLoop == nil { + return c.errorf(node, "break not allowed outside loop") + } + pos := c.emit(node, OpJump, 0) + curLoop.Breaks = append(curLoop.Breaks, pos) + } else if node.Token == token.Continue { + curLoop := c.currentLoop() + if curLoop == nil { + return c.errorf(node, "continue not allowed outside loop") + } + pos := c.emit(node, OpJump, 0) + curLoop.Continues = append(curLoop.Continues, pos) + } else { + panic(fmt.Errorf("invalid branch statement: %s", node.Token.String())) + } + + case *ast.BlockStmt: + for _, stmt := range node.Stmts { + if err := c.Compile(stmt); err != nil { + return err + } + } + + case *ast.AssignStmt: + if err := c.compileAssign(node, node.LHS, node.RHS, node.Token); err != nil { + return err + } + + case *ast.Ident: + symbol, _, ok := c.symbolTable.Resolve(node.Name) + if !ok { + return c.errorf(node, "unresolved reference '%s'", node.Name) + } + + switch symbol.Scope { + case ScopeGlobal: + c.emit(node, OpGetGlobal, symbol.Index) + case ScopeLocal: + c.emit(node, OpGetLocal, symbol.Index) + case ScopeBuiltin: + c.emit(node, OpGetBuiltin, symbol.Index) + case ScopeFree: + c.emit(node, OpGetFree, symbol.Index) + } + + case *ast.ArrayLit: + for _, elem := range node.Elements { + if err := c.Compile(elem); err != nil { + return err + } + } + + c.emit(node, OpArray, len(node.Elements)) + + case *ast.MapLit: + for _, elt := range node.Elements { + // key + c.emit(node, OpConstant, c.addConstant(&objects.String{Value: elt.Key})) + + // value + if err := c.Compile(elt.Value); err != nil { + return err + } + } + + c.emit(node, OpMap, len(node.Elements)*2) + + case *ast.SelectorExpr: // selector on RHS side + if err := c.Compile(node.Expr); err != nil { + return err + } + + if err := c.Compile(node.Sel); err != nil { + return err + } + + c.emit(node, OpIndex) + + case *ast.IndexExpr: + if err := c.Compile(node.Expr); err != nil { + return err + } + + if err := c.Compile(node.Index); err != nil { + return err + } + + c.emit(node, OpIndex) + + case *ast.SliceExpr: + if err := c.Compile(node.Expr); err != nil { + return err + } + + if node.Low != nil { + if err := c.Compile(node.Low); err != nil { + return err + } + } else { + c.emit(node, OpNull) + } + + if node.High != nil { + if err := c.Compile(node.High); err != nil { + return err + } + } else { + c.emit(node, OpNull) + } + + c.emit(node, OpSliceIndex) + + case *ast.FuncLit: + c.enterScope() + + for _, p := range node.Type.Params.List { + s := c.symbolTable.Define(p.Name) + + // function arguments is not assigned directly. + s.LocalAssigned = true + } + + if err := c.Compile(node.Body); err != nil { + return err + } + + // add OpReturn if function returns nothing + if !c.lastInstructionIs(OpReturnValue) && !c.lastInstructionIs(OpReturn) { + c.emit(node, OpReturn) + } + + freeSymbols := c.symbolTable.FreeSymbols() + numLocals := c.symbolTable.MaxSymbols() + instructions, sourceMap := c.leaveScope() + + for _, s := range freeSymbols { + switch s.Scope { + case ScopeLocal: + if !s.LocalAssigned { + // Here, the closure is capturing a local variable that's not yet assigned its value. + // One example is a local recursive function: + // + // func() { + // foo := func(x) { + // // .. + // return foo(x-1) + // } + // } + // + // which translate into + // + // 0000 GETL 0 + // 0002 CLOSURE ? 1 + // 0006 DEFL 0 + // + // . So the local variable (0) is being captured before it's assigned the value. + // + // Solution is to transform the code into something like this: + // + // func() { + // foo := undefined + // foo = func(x) { + // // .. + // return foo(x-1) + // } + // } + // + // that is equivalent to + // + // 0000 NULL + // 0001 DEFL 0 + // 0003 GETL 0 + // 0005 CLOSURE ? 1 + // 0009 SETL 0 + // + + c.emit(node, OpNull) + c.emit(node, OpDefineLocal, s.Index) + + s.LocalAssigned = true + } + + c.emit(node, OpGetLocal, s.Index) + case ScopeFree: + c.emit(node, OpGetFree, s.Index) + } + } + + compiledFunction := &objects.CompiledFunction{ + Instructions: instructions, + NumLocals: numLocals, + NumParameters: len(node.Type.Params.List), + SourceMap: sourceMap, + } + + if len(freeSymbols) > 0 { + c.emit(node, OpClosure, c.addConstant(compiledFunction), len(freeSymbols)) + } else { + c.emit(node, OpConstant, c.addConstant(compiledFunction)) + } + + case *ast.ReturnStmt: + if c.symbolTable.Parent(true) == nil { + // outside the function + return c.errorf(node, "return not allowed outside function") + } + + if node.Result == nil { + c.emit(node, OpReturn) + } else { + if err := c.Compile(node.Result); err != nil { + return err + } + + c.emit(node, OpReturnValue) + } + + case *ast.CallExpr: + if err := c.Compile(node.Func); err != nil { + return err + } + + for _, arg := range node.Args { + if err := c.Compile(arg); err != nil { + return err + } + } + + c.emit(node, OpCall, len(node.Args)) + + case *ast.ImportExpr: + if c.builtinModules[node.ModuleName] { + c.emit(node, OpConstant, c.addConstant(&objects.String{Value: node.ModuleName})) + c.emit(node, OpGetBuiltinModule) + } else { + userMod, err := c.compileModule(node) + if err != nil { + return err + } + + c.emit(node, OpConstant, c.addConstant(userMod)) + c.emit(node, OpCall, 0) + } + + case *ast.ExportStmt: + // export statement must be in top-level scope + if c.scopeIndex != 0 { + return c.errorf(node, "export not allowed inside function") + } + + // export statement is simply ignore when compiling non-module code + if c.parent == nil { + break + } + + if err := c.Compile(node.Result); err != nil { + return err + } + + c.emit(node, OpImmutable) + c.emit(node, OpReturnValue) + + case *ast.ErrorExpr: + if err := c.Compile(node.Expr); err != nil { + return err + } + + c.emit(node, OpError) + + case *ast.ImmutableExpr: + if err := c.Compile(node.Expr); err != nil { + return err + } + + c.emit(node, OpImmutable) + + case *ast.CondExpr: + if err := c.Compile(node.Cond); err != nil { + return err + } + + // first jump placeholder + jumpPos1 := c.emit(node, OpJumpFalsy, 0) + + if err := c.Compile(node.True); err != nil { + return err + } + + // second jump placeholder + jumpPos2 := c.emit(node, OpJump, 0) + + // update first jump offset + curPos := len(c.currentInstructions()) + c.changeOperand(jumpPos1, curPos) + + if err := c.Compile(node.False); err != nil { + return err + } + + // update second jump offset + curPos = len(c.currentInstructions()) + c.changeOperand(jumpPos2, curPos) + } + + return nil +} + +// Bytecode returns a compiled bytecode. +func (c *Compiler) Bytecode() *Bytecode { + return &Bytecode{ + FileSet: c.file.Set(), + MainFunction: &objects.CompiledFunction{ + Instructions: c.currentInstructions(), + SourceMap: c.currentSourceMap(), + }, + Constants: c.constants, + } +} + +// SetModuleLoader sets or replaces the current module loader. +// Note that the module loader is used for user modules, +// not for the standard modules. +func (c *Compiler) SetModuleLoader(moduleLoader ModuleLoader) { + c.moduleLoader = moduleLoader +} + +func (c *Compiler) fork(file *source.File, moduleName string, symbolTable *SymbolTable) *Compiler { + child := NewCompiler(file, symbolTable, nil, c.builtinModules, c.trace) + child.moduleName = moduleName // name of the module to compile + child.parent = c // parent to set to current compiler + child.moduleLoader = c.moduleLoader // share module loader + + return child +} + +func (c *Compiler) errorf(node ast.Node, format string, args ...interface{}) error { + return &Error{ + fileSet: c.file.Set(), + node: node, + error: fmt.Errorf(format, args...), + } +} + +func (c *Compiler) addConstant(o objects.Object) int { + if c.parent != nil { + // module compilers will use their parent's constants array + return c.parent.addConstant(o) + } + + c.constants = append(c.constants, o) + + if c.trace != nil { + c.printTrace(fmt.Sprintf("CONST %04d %s", len(c.constants)-1, o)) + } + + return len(c.constants) - 1 +} + +func (c *Compiler) addInstruction(b []byte) int { + posNewIns := len(c.currentInstructions()) + + c.scopes[c.scopeIndex].instructions = append(c.currentInstructions(), b...) + + return posNewIns +} + +func (c *Compiler) setLastInstruction(op Opcode, pos int) { + c.scopes[c.scopeIndex].lastInstructions[1] = c.scopes[c.scopeIndex].lastInstructions[0] + + c.scopes[c.scopeIndex].lastInstructions[0].Opcode = op + c.scopes[c.scopeIndex].lastInstructions[0].Position = pos +} + +func (c *Compiler) lastInstructionIs(op Opcode) bool { + if len(c.currentInstructions()) == 0 { + return false + } + + return c.scopes[c.scopeIndex].lastInstructions[0].Opcode == op +} + +func (c *Compiler) removeLastInstruction() { + lastPos := c.scopes[c.scopeIndex].lastInstructions[0].Position + + if c.trace != nil { + c.printTrace(fmt.Sprintf("DELET %s", + FormatInstructions(c.scopes[c.scopeIndex].instructions[lastPos:], lastPos)[0])) + } + + c.scopes[c.scopeIndex].instructions = c.currentInstructions()[:lastPos] + c.scopes[c.scopeIndex].lastInstructions[0] = c.scopes[c.scopeIndex].lastInstructions[1] +} + +func (c *Compiler) replaceInstruction(pos int, inst []byte) { + copy(c.currentInstructions()[pos:], inst) + + if c.trace != nil { + c.printTrace(fmt.Sprintf("REPLC %s", + FormatInstructions(c.scopes[c.scopeIndex].instructions[pos:], pos)[0])) + } +} + +func (c *Compiler) changeOperand(opPos int, operand ...int) { + op := Opcode(c.currentInstructions()[opPos]) + inst := MakeInstruction(op, operand...) + + c.replaceInstruction(opPos, inst) +} + +func (c *Compiler) emit(node ast.Node, opcode Opcode, operands ...int) int { + filePos := source.NoPos + if node != nil { + filePos = node.Pos() + } + + inst := MakeInstruction(opcode, operands...) + pos := c.addInstruction(inst) + c.scopes[c.scopeIndex].sourceMap[pos] = filePos + c.setLastInstruction(opcode, pos) + + if c.trace != nil { + c.printTrace(fmt.Sprintf("EMIT %s", + FormatInstructions(c.scopes[c.scopeIndex].instructions[pos:], pos)[0])) + } + + return pos +} + +func (c *Compiler) printTrace(a ...interface{}) { + const ( + dots = ". . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . " + n = len(dots) + ) + + i := 2 * c.indent + for i > n { + _, _ = fmt.Fprint(c.trace, dots) + i -= n + } + _, _ = fmt.Fprint(c.trace, dots[0:i]) + _, _ = fmt.Fprintln(c.trace, a...) +} + +func trace(c *Compiler, msg string) *Compiler { + c.printTrace(msg, "{") + c.indent++ + + return c +} + +func un(c *Compiler) { + c.indent-- + c.printTrace("}") +} |