package tengo

import (
	"fmt"
	"io"
	"io/ioutil"
	"path/filepath"
	"reflect"
	"strings"

	"github.com/d5/tengo/v2/parser"
	"github.com/d5/tengo/v2/token"
)

// compilationScope holds the compilation state of a single scope (the main
// program, a module, or a function literal).
//
// Instructions is the bytecode emitted into the scope so far. SourceMap maps
// the byte position of each emitted instruction to the source position of the
// node it was emitted for. SymbolInit is allocated when the scope is created;
// it is not otherwise referenced in the code visible in this file.
type compilationScope struct {
	Instructions []byte
	SymbolInit   map[string]bool
	SourceMap    map[int]parser.Pos
}

// loop represents a loop construct that the compiler uses to track the current
// loop.
//
// Continues and Breaks record the instruction positions of the placeholder
// jumps emitted for "continue" and "break" statements inside the loop body.
// The loop compilers patch their operands once the jump targets are known.
type loop struct {
	Continues []int
	Breaks    []int
}

// CompilerError represents a compiler error.
//
// Err is the underlying error, Node is the AST node the error is reported
// for, and FileSet is the source file set used to translate the node's
// position into a printable file position.
type CompilerError struct {
	FileSet *parser.SourceFileSet
	Node    parser.Node
	Err     error
}

// Error implements the error interface. The message consists of the text of
// the wrapped error followed by the source position of the offending node.
func (e *CompilerError) Error() string {
	where := e.FileSet.Position(e.Node.Pos())
	cause := e.Err.Error()
	return fmt.Sprintf("Compile Error: %s\n\tat %s", cause, where)
}

// Compiler compiles the AST into a bytecode.
//
// Field overview:
//   - file: source file being compiled; its file set is used for error
//     positions and for registering module files.
//   - parent: compiler that forked this one to compile a module, or nil for
//     the top-level compiler.
//   - modulePath: path of the module being compiled (set by fork); used for
//     cyclic import detection.
//   - constants: constant pool; only the top-level compiler's pool is
//     appended to (see addConstant).
//   - symbolTable: symbol table of the innermost scope/block being compiled.
//   - scopes, scopeIndex: stack of compilation scopes and the index of the
//     current one.
//   - modules: importable modules looked up by import expressions.
//   - compiledModules: cache of already compiled modules keyed by module path.
//   - allowFileImport: whether import expressions may load modules from local
//     files.
//   - loops, loopIndex: stack of loops being compiled and the index of the
//     innermost one (-1 when outside any loop).
//   - trace, indent: optional trace output writer and its current indent
//     level.
type Compiler struct {
	file            *parser.SourceFile
	parent          *Compiler
	modulePath      string
	constants       []Object
	symbolTable     *SymbolTable
	scopes          []compilationScope
	scopeIndex      int
	modules         *ModuleMap
	compiledModules map[string]*CompiledFunction
	allowFileImport bool
	loops           []*loop
	loopIndex       int
	trace           io.Writer
	indent          int
}

// NewCompiler creates a Compiler for the given source file.
//
// A nil symbolTable or a nil modules map is replaced with a freshly created
// one. All builtin functions are registered in the symbol table. constants is
// used as the initial constant pool, and trace, when non-nil, receives
// compilation trace output.
func NewCompiler(
	file *parser.SourceFile,
	symbolTable *SymbolTable,
	constants []Object,
	modules *ModuleMap,
	trace io.Writer,
) *Compiler {
	// fall back to an empty symbol table
	if symbolTable == nil {
		symbolTable = NewSymbolTable()
	}

	// make every builtin function resolvable by name
	for builtinIdx, builtin := range builtinFuncs {
		symbolTable.DefineBuiltin(builtinIdx, builtin.Name)
	}

	// fall back to an empty module map
	if modules == nil {
		modules = NewModuleMap()
	}

	c := &Compiler{
		file:            file,
		symbolTable:     symbolTable,
		constants:       constants,
		modules:         modules,
		compiledModules: make(map[string]*CompiledFunction),
		trace:           trace,
		loopIndex:       -1, // not inside any loop yet
	}

	// the main scope is the only scope to begin with (scopeIndex stays 0)
	c.scopes = []compilationScope{{
		SymbolInit: make(map[string]bool),
		SourceMap:  make(map[int]parser.Pos),
	}}
	return c
}

// Compile compiles the AST node into bytecode that is appended to the current
// compilation scope. It dispatches on the concrete type of node and recurses
// into child nodes as needed. Node types that have no matching case emit
// nothing and yield a nil error.
//
// When a trace writer is set, entry to and exit from every call is written to
// it. The first error encountered is returned and compilation of the node is
// abandoned at that point.
func (c *Compiler) Compile(node parser.Node) error {
	if c.trace != nil {
		if node != nil {
			defer untracec(tracec(c, fmt.Sprintf("%s (%s)",
				node.String(), reflect.TypeOf(node).Elem().Name())))
		} else {
			defer untracec(tracec(c, "<nil>"))
		}
	}

	switch node := node.(type) {
	case *parser.File:
		for _, stmt := range node.Stmts {
			if err := c.Compile(stmt); err != nil {
				return err
			}
		}
	case *parser.ExprStmt:
		if err := c.Compile(node.Expr); err != nil {
			return err
		}
		// emit a pop after the expression of an expression statement
		c.emit(node, parser.OpPop)
	case *parser.IncDecStmt:
		// "x++" / "x--" are compiled as "x += 1" / "x -= 1"
		op := token.AddAssign
		if node.Token == token.Dec {
			op = token.SubAssign
		}
		return c.compileAssign(node, []parser.Expr{node.Expr},
			[]parser.Expr{&parser.IntLit{Value: 1}}, op)
	case *parser.ParenExpr:
		if err := c.Compile(node.Expr); err != nil {
			return err
		}
	case *parser.BinaryExpr:
		if node.Token == token.LAnd || node.Token == token.LOr {
			return c.compileLogical(node)
		}
		// "a < b" and "a <= b" are compiled as "b > a" and "b >= a": the
		// operands are compiled in swapped order (RHS first) and the
		// opposite comparison is emitted.
		if node.Token == token.Less {
			if err := c.Compile(node.RHS); err != nil {
				return err
			}
			if err := c.Compile(node.LHS); err != nil {
				return err
			}
			c.emit(node, parser.OpBinaryOp, int(token.Greater))
			return nil
		} else if node.Token == token.LessEq {
			if err := c.Compile(node.RHS); err != nil {
				return err
			}
			if err := c.Compile(node.LHS); err != nil {
				return err
			}
			c.emit(node, parser.OpBinaryOp, int(token.GreaterEq))
			return nil
		}
		// every other binary operator: LHS, then RHS, then the operator
		if err := c.Compile(node.LHS); err != nil {
			return err
		}
		if err := c.Compile(node.RHS); err != nil {
			return err
		}

		switch node.Token {
		case token.Add:
			c.emit(node, parser.OpBinaryOp, int(token.Add))
		case token.Sub:
			c.emit(node, parser.OpBinaryOp, int(token.Sub))
		case token.Mul:
			c.emit(node, parser.OpBinaryOp, int(token.Mul))
		case token.Quo:
			c.emit(node, parser.OpBinaryOp, int(token.Quo))
		case token.Rem:
			c.emit(node, parser.OpBinaryOp, int(token.Rem))
		case token.Greater:
			c.emit(node, parser.OpBinaryOp, int(token.Greater))
		case token.GreaterEq:
			c.emit(node, parser.OpBinaryOp, int(token.GreaterEq))
		case token.Equal:
			c.emit(node, parser.OpEqual)
		case token.NotEqual:
			c.emit(node, parser.OpNotEqual)
		case token.And:
			c.emit(node, parser.OpBinaryOp, int(token.And))
		case token.Or:
			c.emit(node, parser.OpBinaryOp, int(token.Or))
		case token.Xor:
			c.emit(node, parser.OpBinaryOp, int(token.Xor))
		case token.AndNot:
			c.emit(node, parser.OpBinaryOp, int(token.AndNot))
		case token.Shl:
			c.emit(node, parser.OpBinaryOp, int(token.Shl))
		case token.Shr:
			c.emit(node, parser.OpBinaryOp, int(token.Shr))
		default:
			return c.errorf(node, "invalid binary operator: %s",
				node.Token.String())
		}
	case *parser.IntLit:
		c.emit(node, parser.OpConstant,
			c.addConstant(&Int{Value: node.Value}))
	case *parser.FloatLit:
		c.emit(node, parser.OpConstant,
			c.addConstant(&Float{Value: node.Value}))
	case *parser.BoolLit:
		if node.Value {
			c.emit(node, parser.OpTrue)
		} else {
			c.emit(node, parser.OpFalse)
		}
	case *parser.StringLit:
		// reject string literals longer than MaxStringLen at compile time
		if len(node.Value) > MaxStringLen {
			return c.error(node, ErrStringLimit)
		}
		c.emit(node, parser.OpConstant,
			c.addConstant(&String{Value: node.Value}))
	case *parser.CharLit:
		c.emit(node, parser.OpConstant,
			c.addConstant(&Char{Value: node.Value}))
	case *parser.UndefinedLit:
		c.emit(node, parser.OpNull)
	case *parser.UnaryExpr:
		if err := c.Compile(node.Expr); err != nil {
			return err
		}

		switch node.Token {
		case token.Not:
			c.emit(node, parser.OpLNot)
		case token.Sub:
			c.emit(node, parser.OpMinus)
		case token.Xor:
			c.emit(node, parser.OpBComplement)
		case token.Add:
			// unary plus: no instruction is emitted for the operator
		default:
			return c.errorf(node,
				"invalid unary operator: %s", node.Token.String())
		}
	case *parser.IfStmt:
		// open new symbol table for the statement
		c.symbolTable = c.symbolTable.Fork(true)
		defer func() {
			c.symbolTable = c.symbolTable.Parent(false)
		}()

		if node.Init != nil {
			if err := c.Compile(node.Init); err != nil {
				return err
			}
		}
		if err := c.Compile(node.Cond); err != nil {
			return err
		}

		// first jump placeholder
		jumpPos1 := c.emit(node, parser.OpJumpFalsy, 0)
		if err := c.Compile(node.Body); err != nil {
			return err
		}
		if node.Else != nil {
			// second jump placeholder
			jumpPos2 := c.emit(node, parser.OpJump, 0)

			// update first jump offset
			curPos := len(c.currentInstructions())
			c.changeOperand(jumpPos1, curPos)
			if err := c.Compile(node.Else); err != nil {
				return err
			}

			// update second jump offset
			curPos = len(c.currentInstructions())
			c.changeOperand(jumpPos2, curPos)
		} else {
			// update first jump offset
			curPos := len(c.currentInstructions())
			c.changeOperand(jumpPos1, curPos)
		}
	case *parser.ForStmt:
		return c.compileForStmt(node)
	case *parser.ForInStmt:
		return c.compileForInStmt(node)
	case *parser.BranchStmt:
		// break/continue emit a placeholder jump that is recorded on the
		// current loop; the loop compiler patches the target afterwards.
		if node.Token == token.Break {
			curLoop := c.currentLoop()
			if curLoop == nil {
				return c.errorf(node, "break not allowed outside loop")
			}
			pos := c.emit(node, parser.OpJump, 0)
			curLoop.Breaks = append(curLoop.Breaks, pos)
		} else if node.Token == token.Continue {
			curLoop := c.currentLoop()
			if curLoop == nil {
				return c.errorf(node, "continue not allowed outside loop")
			}
			pos := c.emit(node, parser.OpJump, 0)
			curLoop.Continues = append(curLoop.Continues, pos)
		} else {
			panic(fmt.Errorf("invalid branch statement: %s",
				node.Token.String()))
		}
	case *parser.BlockStmt:
		// an empty block emits nothing and needs no symbol table
		if len(node.Stmts) == 0 {
			return nil
		}

		// the statements of the block get their own symbol table
		c.symbolTable = c.symbolTable.Fork(true)
		defer func() {
			c.symbolTable = c.symbolTable.Parent(false)
		}()

		for _, stmt := range node.Stmts {
			if err := c.Compile(stmt); err != nil {
				return err
			}
		}
	case *parser.AssignStmt:
		err := c.compileAssign(node, node.LHS, node.RHS, node.Token)
		if err != nil {
			return err
		}
	case *parser.Ident:
		symbol, _, ok := c.symbolTable.Resolve(node.Name)
		if !ok {
			return c.errorf(node, "unresolved reference '%s'", node.Name)
		}

		// the load instruction depends on the scope the symbol resolved to
		switch symbol.Scope {
		case ScopeGlobal:
			c.emit(node, parser.OpGetGlobal, symbol.Index)
		case ScopeLocal:
			c.emit(node, parser.OpGetLocal, symbol.Index)
		case ScopeBuiltin:
			c.emit(node, parser.OpGetBuiltin, symbol.Index)
		case ScopeFree:
			c.emit(node, parser.OpGetFree, symbol.Index)
		}
	case *parser.ArrayLit:
		for _, elem := range node.Elements {
			if err := c.Compile(elem); err != nil {
				return err
			}
		}
		c.emit(node, parser.OpArray, len(node.Elements))
	case *parser.MapLit:
		for _, elt := range node.Elements {
			// key
			if len(elt.Key) > MaxStringLen {
				return c.error(node, ErrStringLimit)
			}
			c.emit(node, parser.OpConstant,
				c.addConstant(&String{Value: elt.Key}))

			// value
			if err := c.Compile(elt.Value); err != nil {
				return err
			}
		}
		// the operand counts keys and values together
		c.emit(node, parser.OpMap, len(node.Elements)*2)

	case *parser.SelectorExpr: // selector on RHS side
		if err := c.Compile(node.Expr); err != nil {
			return err
		}
		if err := c.Compile(node.Sel); err != nil {
			return err
		}
		c.emit(node, parser.OpIndex)
	case *parser.IndexExpr:
		if err := c.Compile(node.Expr); err != nil {
			return err
		}
		if err := c.Compile(node.Index); err != nil {
			return err
		}
		c.emit(node, parser.OpIndex)
	case *parser.SliceExpr:
		if err := c.Compile(node.Expr); err != nil {
			return err
		}
		// an omitted low/high bound is emitted as OpNull
		if node.Low != nil {
			if err := c.Compile(node.Low); err != nil {
				return err
			}
		} else {
			c.emit(node, parser.OpNull)
		}
		if node.High != nil {
			if err := c.Compile(node.High); err != nil {
				return err
			}
		} else {
			c.emit(node, parser.OpNull)
		}
		c.emit(node, parser.OpSliceIndex)
	case *parser.FuncLit:
		c.enterScope()

		for _, p := range node.Type.Params.List {
			s := c.symbolTable.Define(p.Name)

			// parameters are never assigned by an explicit statement, so
			// mark them as already assigned.
			s.LocalAssigned = true
		}

		if err := c.Compile(node.Body); err != nil {
			return err
		}

		// code optimization
		c.optimizeFunc(node)

		// collect everything needed from the function scope before
		// leaving it
		freeSymbols := c.symbolTable.FreeSymbols()
		numLocals := c.symbolTable.MaxSymbols()
		instructions, sourceMap := c.leaveScope()

		// back in the enclosing scope: emit a pointer load for every
		// variable captured by the function
		for _, s := range freeSymbols {
			switch s.Scope {
			case ScopeLocal:
				if !s.LocalAssigned {
					// Here, the closure is capturing a local variable that's
					// not yet assigned its value. One example is a local
					// recursive function:
					//
					//   func() {
					//     foo := func(x) {
					//       // ..
					//       return foo(x-1)
					//     }
					//   }
					//
					// which translates into
					//
					//   0000 GETL    0
					//   0002 CLOSURE ?     1
					//   0006 DEFL    0
					//
					// So the local variable (0) is being captured before
					// it's assigned the value.
					//
					// The solution is to transform the code into something
					// like this:
					//
					//   func() {
					//     foo := undefined
					//     foo = func(x) {
					//       // ..
					//       return foo(x-1)
					//     }
					//   }
					//
					// that is equivalent to
					//
					//   0000 NULL
					//   0001 DEFL    0
					//   0003 GETL    0
					//   0005 CLOSURE ?     1
					//   0009 SETL    0
					//
					c.emit(node, parser.OpNull)
					c.emit(node, parser.OpDefineLocal, s.Index)
					s.LocalAssigned = true
				}
				c.emit(node, parser.OpGetLocalPtr, s.Index)
			case ScopeFree:
				c.emit(node, parser.OpGetFreePtr, s.Index)
			}
		}

		compiledFunction := &CompiledFunction{
			Instructions:  instructions,
			NumLocals:     numLocals,
			NumParameters: len(node.Type.Params.List),
			VarArgs:       node.Type.Params.VarArgs,
			SourceMap:     sourceMap,
		}
		// a function that captures variables is emitted as a closure,
		// otherwise as a plain constant
		if len(freeSymbols) > 0 {
			c.emit(node, parser.OpClosure,
				c.addConstant(compiledFunction), len(freeSymbols))
		} else {
			c.emit(node, parser.OpConstant, c.addConstant(compiledFunction))
		}
	case *parser.ReturnStmt:
		if c.symbolTable.Parent(true) == nil {
			// outside the function
			return c.errorf(node, "return not allowed outside function")
		}

		// the operand is the number of returned values (0 or 1)
		if node.Result == nil {
			c.emit(node, parser.OpReturn, 0)
		} else {
			if err := c.Compile(node.Result); err != nil {
				return err
			}
			c.emit(node, parser.OpReturn, 1)
		}
	case *parser.CallExpr:
		if err := c.Compile(node.Func); err != nil {
			return err
		}
		for _, arg := range node.Args {
			if err := c.Compile(arg); err != nil {
				return err
			}
		}
		c.emit(node, parser.OpCall, len(node.Args))
	case *parser.ImportExpr:
		if node.ModuleName == "" {
			return c.errorf(node, "empty module name")
		}

		// modules registered in the module map take precedence over files
		if mod := c.modules.Get(node.ModuleName); mod != nil {
			v, err := mod.Import(node.ModuleName)
			if err != nil {
				return err
			}

			switch v := v.(type) {
			case []byte: // module written in Tengo
				compiled, err := c.compileModule(node,
					node.ModuleName, node.ModuleName, v)
				if err != nil {
					return err
				}
				// load the compiled module function and call it with no
				// arguments
				c.emit(node, parser.OpConstant, c.addConstant(compiled))
				c.emit(node, parser.OpCall, 0)
			case Object: // builtin module
				c.emit(node, parser.OpConstant, c.addConstant(v))
			default:
				panic(fmt.Errorf("invalid import value type: %T", v))
			}
		} else if c.allowFileImport {
			// the ".tengo" extension is optional in the import name
			moduleName := node.ModuleName
			if !strings.HasSuffix(moduleName, ".tengo") {
				moduleName += ".tengo"
			}

			// the absolute path identifies the module for cyclic import
			// detection and for the compiled module cache
			modulePath, err := filepath.Abs(moduleName)
			if err != nil {
				return c.errorf(node, "module file path error: %s",
					err.Error())
			}

			if err := c.checkCyclicImports(node, modulePath); err != nil {
				return err
			}

			moduleSrc, err := ioutil.ReadFile(moduleName)
			if err != nil {
				return c.errorf(node, "module file read error: %s",
					err.Error())
			}

			compiled, err := c.compileModule(node,
				moduleName, modulePath, moduleSrc)
			if err != nil {
				return err
			}
			c.emit(node, parser.OpConstant, c.addConstant(compiled))
			c.emit(node, parser.OpCall, 0)
		} else {
			return c.errorf(node, "module '%s' not found", node.ModuleName)
		}
	case *parser.ExportStmt:
		// export statement must be in top-level scope
		if c.scopeIndex != 0 {
			return c.errorf(node, "export not allowed inside function")
		}

		// export statement is simply ignored when compiling non-module code
		if c.parent == nil {
			break
		}
		if err := c.Compile(node.Result); err != nil {
			return err
		}
		// the exported value is made immutable and returned from the
		// module function
		c.emit(node, parser.OpImmutable)
		c.emit(node, parser.OpReturn, 1)
	case *parser.ErrorExpr:
		if err := c.Compile(node.Expr); err != nil {
			return err
		}
		c.emit(node, parser.OpError)
	case *parser.ImmutableExpr:
		if err := c.Compile(node.Expr); err != nil {
			return err
		}
		c.emit(node, parser.OpImmutable)
	case *parser.CondExpr:
		if err := c.Compile(node.Cond); err != nil {
			return err
		}

		// first jump placeholder
		jumpPos1 := c.emit(node, parser.OpJumpFalsy, 0)
		if err := c.Compile(node.True); err != nil {
			return err
		}

		// second jump placeholder
		jumpPos2 := c.emit(node, parser.OpJump, 0)

		// update first jump offset
		curPos := len(c.currentInstructions())
		c.changeOperand(jumpPos1, curPos)
		if err := c.Compile(node.False); err != nil {
			return err
		}

		// update second jump offset
		curPos = len(c.currentInstructions())
		c.changeOperand(jumpPos2, curPos)
	}
	return nil
}

// Bytecode returns the compiled bytecode: the file set of the source file,
// the constant pool, and a main function made of the current scope's
// instructions terminated by an OpSuspend instruction.
func (c *Compiler) Bytecode() *Bytecode {
	fileSet := c.file.Set()

	// terminate the main function with a suspend instruction
	mainFn := &CompiledFunction{
		Instructions: append(c.currentInstructions(), parser.OpSuspend),
		SourceMap:    c.currentSourceMap(),
	}

	bc := new(Bytecode)
	bc.FileSet = fileSet
	bc.MainFunction = mainFn
	bc.Constants = c.constants
	return bc
}

// EnableFileImport enables or disables module loading from local files.
// Local file modules are disabled by default. When enabled, an import
// expression whose name is not found in the module map is resolved as a
// ".tengo" file path.
func (c *Compiler) EnableFileImport(enable bool) {
	c.allowFileImport = enable
}

// compileAssign compiles an assignment, definition (":=") or compound
// assignment ("+=", "-=", ...) of rhs to lhs. node is the statement the
// emitted instructions are attributed to.
//
// Only a single left-hand and a single right-hand expression are supported.
// The left-hand side may be an identifier optionally followed by selector or
// index expressions; selectors are not allowed with ":=". It returns an error
// for tuple assignments, redeclarations in the same block and unresolved
// identifiers, and panics if the target symbol has a scope that cannot be
// assigned to.
func (c *Compiler) compileAssign(
	node parser.Node,
	lhs, rhs []parser.Expr,
	op token.Token,
) error {
	if len(lhs) > 1 || len(rhs) > 1 {
		return c.errorf(node, "tuple assignment not allowed")
	}

	// split the target into its base identifier and selector chain
	target := lhs[0]
	ident, selectors := resolveAssignLHS(target)
	numSel := len(selectors)
	hasSel := numSel > 0
	isDefine := op == token.Define

	if isDefine && hasSel {
		// a selector on a variable that is just being defined is meaningless
		return c.errorf(node, "operator ':=' not allowed with selector")
	}

	symbol, depth, exists := c.symbolTable.Resolve(ident)
	switch {
	case isDefine && depth == 0 && exists:
		return c.errorf(node, "'%s' redeclared in this block", ident)
	case isDefine:
		symbol = c.symbolTable.Define(ident)
	case !exists:
		return c.errorf(node, "unresolved reference '%s'", ident)
	}

	// compound assignments need the current value of the target first
	if !isDefine && op != token.Assign {
		if err := c.Compile(target); err != nil {
			return err
		}
	}

	// right-hand side value(s)
	for _, expr := range rhs {
		if err := c.Compile(expr); err != nil {
			return err
		}
	}

	// map the compound assignment token to its binary operator
	binOp, compound := token.Add, true
	switch op {
	case token.AddAssign:
		binOp = token.Add
	case token.SubAssign:
		binOp = token.Sub
	case token.MulAssign:
		binOp = token.Mul
	case token.QuoAssign:
		binOp = token.Quo
	case token.RemAssign:
		binOp = token.Rem
	case token.AndAssign:
		binOp = token.And
	case token.OrAssign:
		binOp = token.Or
	case token.AndNotAssign:
		binOp = token.AndNot
	case token.XorAssign:
		binOp = token.Xor
	case token.ShlAssign:
		binOp = token.Shl
	case token.ShrAssign:
		binOp = token.Shr
	default:
		compound = false
	}
	if compound {
		c.emit(node, parser.OpBinaryOp, int(binOp))
	}

	// selector expressions are compiled from the last one to the first
	for i := numSel - 1; i >= 0; i-- {
		if err := c.Compile(selectors[i]); err != nil {
			return err
		}
	}

	// the store instruction depends on the scope of the target symbol
	switch symbol.Scope {
	case ScopeGlobal:
		if hasSel {
			c.emit(node, parser.OpSetSelGlobal, symbol.Index, numSel)
		} else {
			c.emit(node, parser.OpSetGlobal, symbol.Index)
		}
	case ScopeLocal:
		switch {
		case hasSel:
			c.emit(node, parser.OpSetSelLocal, symbol.Index, numSel)
		case isDefine && !symbol.LocalAssigned:
			c.emit(node, parser.OpDefineLocal, symbol.Index)
		default:
			c.emit(node, parser.OpSetLocal, symbol.Index)
		}

		// remember that the local now has a value
		symbol.LocalAssigned = true
	case ScopeFree:
		if hasSel {
			c.emit(node, parser.OpSetSelFree, symbol.Index, numSel)
		} else {
			c.emit(node, parser.OpSetFree, symbol.Index)
		}
	default:
		panic(fmt.Errorf("invalid assignment variable scope: %s",
			symbol.Scope))
	}
	return nil
}

// compileLogical compiles a logical "&&" or "||" expression. The left operand
// is compiled first, followed by an OpAndJump ("&&") or OpOrJump ("||")
// instruction whose operand is patched to the position right after the
// right operand's instructions.
func (c *Compiler) compileLogical(node *parser.BinaryExpr) error {
	if err := c.Compile(node.LHS); err != nil {
		return err
	}

	// pick the jump instruction that matches the operator
	var jumpOp parser.Opcode = parser.OpOrJump
	if node.Token == token.LAnd {
		jumpOp = parser.OpAndJump
	}
	jumpPos := c.emit(node, jumpOp, 0)

	if err := c.Compile(node.RHS); err != nil {
		return err
	}

	// the jump skips over the right operand
	endPos := len(c.currentInstructions())
	c.changeOperand(jumpPos, endPos)
	return nil
}

// compileForStmt compiles a "for init; cond; post { body }" statement.
//
// The emitted layout is: init, condition, conditional exit jump (only when a
// condition is present), body, post statement, jump back to the condition.
// "break" jumps are patched to the position after the loop and "continue"
// jumps to the start of the post statement.
func (c *Compiler) compileForStmt(stmt *parser.ForStmt) error {
	// the whole statement lives in its own block-level symbol table
	c.symbolTable = c.symbolTable.Fork(true)
	defer func() {
		c.symbolTable = c.symbolTable.Parent(false)
	}()

	if stmt.Init != nil {
		if err := c.Compile(stmt.Init); err != nil {
			return err
		}
	}

	// the loop jumps back to here after every iteration
	condStart := len(c.currentInstructions())

	// optional condition followed by its exit jump placeholder
	hasCond := stmt.Cond != nil
	exitJump := -1
	if hasCond {
		if err := c.Compile(stmt.Cond); err != nil {
			return err
		}
		exitJump = c.emit(stmt, parser.OpJumpFalsy, 0)
	}

	// the body is compiled with the loop registered as the current loop
	curLoop := c.enterLoop()
	bodyErr := c.Compile(stmt.Body)
	c.leaveLoop()
	if bodyErr != nil {
		return bodyErr
	}

	// "continue" resumes at the post statement
	continueTarget := len(c.currentInstructions())
	if stmt.Post != nil {
		if err := c.Compile(stmt.Post); err != nil {
			return err
		}
	}

	// next iteration
	c.emit(stmt, parser.OpJump, condStart)

	// everything that leaves the loop lands here
	exitTarget := len(c.currentInstructions())
	if hasCond {
		c.changeOperand(exitJump, exitTarget)
	}
	for _, breakPos := range curLoop.Breaks {
		c.changeOperand(breakPos, exitTarget)
	}
	for _, continuePos := range curLoop.Continues {
		c.changeOperand(continuePos, continueTarget)
	}
	return nil
}

// compileForInStmt compiles a "for key, value in iterable { body }"
// statement.
//
// The statement is compiled roughly like the following pseudo code:
//
//   for :it := iterator(iterable); :it.next();  {
//     k, v := :it.get()  // DEFINE operator
//
//     ... body ...
//   }
//
// ":it" is a hidden variable holding the iterator. The name is defined in the
// statement's own symbol table and starts with ":", presumably so that it
// cannot clash with an identifier written by the user. A key or value named
// "_" is not assigned.
func (c *Compiler) compileForInStmt(stmt *parser.ForInStmt) error {
	c.symbolTable = c.symbolTable.Fork(true)
	defer func() {
		c.symbolTable = c.symbolTable.Parent(false)
	}()

	itSymbol := c.symbolTable.Define(":it")

	// loadIterator emits the instruction that loads the hidden iterator.
	loadIterator := func() {
		if itSymbol.Scope == ScopeGlobal {
			c.emit(stmt, parser.OpGetGlobal, itSymbol.Index)
		} else {
			c.emit(stmt, parser.OpGetLocal, itSymbol.Index)
		}
	}
	// storeInto emits the instruction that stores into a freshly defined
	// variable with the given index.
	storeInto := func(global bool, index int) {
		if global {
			c.emit(stmt, parser.OpSetGlobal, index)
		} else {
			c.emit(stmt, parser.OpDefineLocal, index)
		}
	}

	// init: create the iterator from the iterable and keep it in ":it"
	if err := c.Compile(stmt.Iterable); err != nil {
		return err
	}
	c.emit(stmt, parser.OpIteratorInit)
	storeInto(itSymbol.Scope == ScopeGlobal, itSymbol.Index)

	// the loop jumps back to here after every iteration
	condStart := len(c.currentInstructions())

	// condition: advance the iterator and leave the loop on a falsy result
	loadIterator()
	c.emit(stmt, parser.OpIteratorNext)
	exitJump := c.emit(stmt, parser.OpJumpFalsy, 0)

	curLoop := c.enterLoop()

	// key variable
	if keyName := stmt.Key.Name; keyName != "_" {
		keySymbol := c.symbolTable.Define(keyName)
		loadIterator()
		c.emit(stmt, parser.OpIteratorKey)
		storeInto(keySymbol.Scope == ScopeGlobal, keySymbol.Index)
	}

	// value variable
	if valueName := stmt.Value.Name; valueName != "_" {
		valueSymbol := c.symbolTable.Define(valueName)
		loadIterator()
		c.emit(stmt, parser.OpIteratorValue)
		storeInto(valueSymbol.Scope == ScopeGlobal, valueSymbol.Index)
	}

	// body
	bodyErr := c.Compile(stmt.Body)
	c.leaveLoop()
	if bodyErr != nil {
		return bodyErr
	}

	// "continue" lands on the jump back to the condition
	continueTarget := len(c.currentInstructions())
	c.emit(stmt, parser.OpJump, condStart)

	// everything that leaves the loop lands here
	exitTarget := len(c.currentInstructions())
	c.changeOperand(exitJump, exitTarget)
	for _, breakPos := range curLoop.Breaks {
		c.changeOperand(breakPos, exitTarget)
	}
	for _, continuePos := range curLoop.Continues {
		c.changeOperand(continuePos, continueTarget)
	}
	return nil
}

// checkCyclicImports reports an error if modulePath is the module path of
// this compiler or of any of its ancestors, i.e. if the module is already
// being compiled further up the import chain. node is the node the error is
// reported for.
func (c *Compiler) checkCyclicImports(
	node parser.Node,
	modulePath string,
) error {
	// walk from this compiler up to the top-level compiler
	for cur := c; cur != nil; cur = cur.parent {
		if cur.modulePath == modulePath {
			return cur.errorf(node, "cyclic module import: %s", modulePath)
		}
	}
	return nil
}

// compileModule compiles the module source src into a function and returns
// it. moduleName is the name the source is registered under in the file set
// and modulePath is the key used for cyclic import detection and for the
// compiled module cache. A module that was compiled before is returned from
// the cache. node is the import node errors are reported for.
func (c *Compiler) compileModule(
	node parser.Node,
	moduleName, modulePath string,
	src []byte,
) (*CompiledFunction, error) {
	if err := c.checkCyclicImports(node, modulePath); err != nil {
		return nil, err
	}

	// reuse the result if the module was compiled already
	if cached, ok := c.loadCompiledModule(modulePath); ok {
		return cached, nil
	}

	// parse the module source as a new file of the same file set
	modFile := c.file.Set().AddFile(moduleName, -1, len(src))
	parsed, err := parser.NewParser(modFile, src, nil).ParseFile()
	if err != nil {
		return nil, err
	}

	// the module only inherits the builtin functions ...
	builtinTable := NewSymbolTable()
	for _, sym := range c.symbolTable.BuiltinSymbols() {
		builtinTable.DefineBuiltin(sym.Index, sym.Name)
	}

	// ... and has no global scope of its own
	moduleTable := builtinTable.Fork(false)

	// compile the module with a forked compiler
	moduleCompiler := c.fork(modFile, modulePath, moduleTable)
	if err := moduleCompiler.Compile(parsed); err != nil {
		return nil, err
	}

	// code optimization
	moduleCompiler.optimizeFunc(node)

	moduleFn := moduleCompiler.Bytecode().MainFunction
	moduleFn.NumLocals = moduleTable.MaxSymbols()
	c.storeCompiledModule(modulePath, moduleFn)
	return moduleFn, nil
}

// loadCompiledModule looks up a previously compiled module by its path. The
// lookup is always performed in the cache of the top-level compiler. ok
// reports whether the module was found.
func (c *Compiler) loadCompiledModule(
	modulePath string,
) (mod *CompiledFunction, ok bool) {
	// find the top-level compiler
	root := c
	for root.parent != nil {
		root = root.parent
	}
	mod, ok = root.compiledModules[modulePath]
	return mod, ok
}

// storeCompiledModule records the compiled module under modulePath in the
// cache of this compiler and of every ancestor up to the top-level compiler.
func (c *Compiler) storeCompiledModule(
	modulePath string,
	module *CompiledFunction,
) {
	for cur := c; cur != nil; cur = cur.parent {
		cur.compiledModules[modulePath] = module
	}
}

// enterLoop pushes a new loop onto the loop stack, makes it the current loop
// and returns it.
func (c *Compiler) enterLoop() *loop {
	entered := new(loop)
	c.loops = append(c.loops, entered)
	c.loopIndex++
	if c.trace == nil {
		return entered
	}
	c.printTrace("LOOPE", c.loopIndex)
	return entered
}

// leaveLoop pops the innermost loop off the loop stack.
func (c *Compiler) leaveLoop() {
	if c.trace != nil {
		c.printTrace("LOOPL", c.loopIndex)
	}
	last := len(c.loops) - 1
	c.loops = c.loops[:last]
	c.loopIndex--
}

// currentLoop returns the innermost loop being compiled, or nil when the
// compiler is not inside any loop.
func (c *Compiler) currentLoop() *loop {
	if c.loopIndex < 0 {
		return nil
	}
	return c.loops[c.loopIndex]
}

// currentInstructions returns the instructions emitted so far into the
// current compilation scope.
func (c *Compiler) currentInstructions() []byte {
	return c.scopes[c.scopeIndex].Instructions
}

// currentSourceMap returns the source map (instruction position to source
// position) of the current compilation scope.
func (c *Compiler) currentSourceMap() map[int]parser.Pos {
	return c.scopes[c.scopeIndex].SourceMap
}

// enterScope pushes a new, empty compilation scope, makes it the current
// scope and forks a non-block symbol table for it.
func (c *Compiler) enterScope() {
	c.scopes = append(c.scopes, compilationScope{
		SymbolInit: make(map[string]bool),
		SourceMap:  make(map[int]parser.Pos),
	})
	c.scopeIndex++
	c.symbolTable = c.symbolTable.Fork(false)
	if c.trace == nil {
		return
	}
	c.printTrace("SCOPE", c.scopeIndex)
}

// leaveScope pops the current compilation scope, restores the enclosing
// symbol table and returns the instructions and source map of the scope that
// was left.
func (c *Compiler) leaveScope() (
	instructions []byte,
	sourceMap map[int]parser.Pos,
) {
	// grab the scope being closed before it is popped
	closing := c.scopes[c.scopeIndex]

	c.scopes = c.scopes[:len(c.scopes)-1]
	c.scopeIndex--
	c.symbolTable = c.symbolTable.Parent(true)
	if c.trace != nil {
		c.printTrace("SCOPL", c.scopeIndex)
	}
	return closing.Instructions, closing.SourceMap
}

// fork creates a child compiler that compiles the module in file using the
// given symbol table. The child shares the module map and trace writer of c,
// records c as its parent and modulePath as the module it compiles, and
// inherits the file import setting so that a module can import local files
// whenever its importer is allowed to.
func (c *Compiler) fork(
	file *parser.SourceFile,
	modulePath string,
	symbolTable *SymbolTable,
) *Compiler {
	child := NewCompiler(file, symbolTable, nil, c.modules, c.trace)
	child.modulePath = modulePath // module file path
	child.parent = c              // parent to set to current compiler

	// Without this, file imports would only work in the top-level source:
	// an imported module importing another file would always fail with
	// "module not found".
	child.allowFileImport = c.allowFileImport
	return child
}

// error wraps err into a CompilerError that refers to node and to the file
// set of the file being compiled.
func (c *Compiler) error(node parser.Node, err error) error {
	compileErr := new(CompilerError)
	compileErr.FileSet = c.file.Set()
	compileErr.Node = node
	compileErr.Err = err
	return compileErr
}

// errorf creates a CompilerError for node whose underlying error is built
// from format and args in the manner of fmt.Errorf.
func (c *Compiler) errorf(
	node parser.Node,
	format string,
	args ...interface{},
) error {
	fileSet := c.file.Set()
	cause := fmt.Errorf(format, args...)
	return &CompilerError{FileSet: fileSet, Node: node, Err: cause}
}

// addConstant appends o to the constant pool and returns its index. Module
// compilers delegate to their parent, so all constants end up in the pool of
// the top-level compiler.
func (c *Compiler) addConstant(o Object) int {
	if c.parent != nil {
		// module compilers will use their parent's constants array
		return c.parent.addConstant(o)
	}

	idx := len(c.constants)
	c.constants = append(c.constants, o)
	if c.trace != nil {
		c.printTrace(fmt.Sprintf("CONST %04d %s", idx, o))
	}
	return idx
}

// addInstruction appends the encoded instruction b to the current scope and
// returns the position at which it was placed.
func (c *Compiler) addInstruction(b []byte) int {
	scope := &c.scopes[c.scopeIndex]
	start := len(scope.Instructions)
	scope.Instructions = append(scope.Instructions, b...)
	return start
}

// replaceInstruction overwrites the bytes of the current scope's
// instructions starting at pos with the encoded instruction inst.
func (c *Compiler) replaceInstruction(pos int, inst []byte) {
	insts := c.currentInstructions()
	copy(insts[pos:], inst)
	if c.trace == nil {
		return
	}
	formatted := FormatInstructions(insts[pos:], pos)
	c.printTrace(fmt.Sprintf("REPLC %s", formatted[0]))
}

// changeOperand re-encodes the instruction at opPos with the same opcode but
// the given operand(s) and writes it back in place.
func (c *Compiler) changeOperand(opPos int, operand ...int) {
	opcode := c.currentInstructions()[opPos]
	c.replaceInstruction(opPos, MakeInstruction(opcode, operand...))
}

// optimizeFunc performs some code-level optimization for the current function
// instructions. It also removes unreachable (dead code) instructions and adds
// a "return" instruction if needed.
//
// The instructions and the source map of the current scope are replaced with
// the optimized versions. node is the node the appended "return" instruction
// is attributed to. It panics if a jump refers to a position that is neither
// a kept instruction nor the end of the function.
func (c *Compiler) optimizeFunc(node parser.Node) {
	// any instructions between RETURN and the function end
	// or instructions between RETURN and jump target position
	// are considered as unreachable.

	// pass 1. identify all jump destinations
	dsts := make(map[int]bool)
	iterateInstructions(c.scopes[c.scopeIndex].Instructions,
		func(pos int, opcode parser.Opcode, operands []int) bool {
			switch opcode {
			case parser.OpJump, parser.OpJumpFalsy,
				parser.OpAndJump, parser.OpOrJump:
				dsts[operands[0]] = true
			}
			return true
		})

	// pass 2. eliminate dead code
	var newInsts []byte
	posMap := make(map[int]int) // old position to new position
	var deadCode bool
	iterateInstructions(c.scopes[c.scopeIndex].Instructions,
		func(pos int, opcode parser.Opcode, operands []int) bool {
			// A jump destination is reachable no matter what precedes it.
			// This must be checked before anything else: a RETURN that is
			// itself a jump target (e.g. "if a { return }; return") must
			// be kept even when it directly follows another RETURN.
			if dsts[pos] {
				deadCode = false
			}
			if deadCode {
				return true
			}
			posMap[pos] = len(newInsts)
			newInsts = append(newInsts,
				MakeInstruction(opcode, operands...)...)
			if opcode == parser.OpReturn {
				// everything up to the next jump destination is dead
				deadCode = true
			}
			return true
		})

	// pass 3. update jump positions
	var lastOp parser.Opcode
	var appendReturn bool
	endPos := len(c.scopes[c.scopeIndex].Instructions)
	iterateInstructions(newInsts,
		func(pos int, opcode parser.Opcode, operands []int) bool {
			switch opcode {
			case parser.OpJump, parser.OpJumpFalsy, parser.OpAndJump,
				parser.OpOrJump:
				newDst, ok := posMap[operands[0]]
				if ok {
					copy(newInsts[pos:],
						MakeInstruction(opcode, newDst))
				} else if endPos == operands[0] {
					// there's a jump instruction that jumps to the end of
					// function; the compiler appends a "return" there.
					// Dead code removal may have shortened the function,
					// so retarget the jump to the new end, which is where
					// that "return" is going to be placed.
					copy(newInsts[pos:],
						MakeInstruction(opcode, len(newInsts)))
					appendReturn = true
				} else {
					panic(fmt.Errorf("invalid jump position: %d",
						operands[0]))
				}
			}
			lastOp = opcode
			return true
		})
	if lastOp != parser.OpReturn {
		appendReturn = true
	}

	// pass 4. update source map
	newSourceMap := make(map[int]parser.Pos)
	for pos, srcPos := range c.scopes[c.scopeIndex].SourceMap {
		newPos, ok := posMap[pos]
		if ok {
			newSourceMap[newPos] = srcPos
		}
	}
	c.scopes[c.scopeIndex].Instructions = newInsts
	c.scopes[c.scopeIndex].SourceMap = newSourceMap

	// append "return"
	if appendReturn {
		c.emit(node, parser.OpReturn, 0)
	}
}

// emit encodes an instruction from opcode and operands, appends it to the
// current scope and records the source position of node for it in the
// scope's source map (parser.NoPos when node is nil). It returns the position
// of the new instruction.
func (c *Compiler) emit(
	node parser.Node,
	opcode parser.Opcode,
	operands ...int,
) int {
	srcPos := parser.NoPos
	if node != nil {
		srcPos = node.Pos()
	}

	encoded := MakeInstruction(opcode, operands...)

	// append to the current scope and remember where it came from
	scope := &c.scopes[c.scopeIndex]
	pos := len(scope.Instructions)
	scope.Instructions = append(scope.Instructions, encoded...)
	scope.SourceMap[pos] = srcPos

	if c.trace == nil {
		return pos
	}
	formatted := FormatInstructions(scope.Instructions[pos:], pos)
	c.printTrace(fmt.Sprintf("EMIT  %s", formatted[0]))
	return pos
}

// printTrace writes one line to the trace writer: an indentation made of
// dots, two characters per indent level, followed by the operands formatted
// as by fmt.Fprintln. Write errors are deliberately ignored because tracing
// is a best-effort debugging aid.
func (c *Compiler) printTrace(a ...interface{}) {
	const dots = ". . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . "

	// emit whole runs of the dots pattern while the indent is wider than it
	pad := 2 * c.indent
	for ; pad > len(dots); pad -= len(dots) {
		_, _ = fmt.Fprint(c.trace, dots)
	}
	_, _ = fmt.Fprint(c.trace, dots[:pad])
	_, _ = fmt.Fprintln(c.trace, a...)
}

// resolveAssignLHS splits the left-hand side of an assignment into the name
// of its base identifier and the selector/index expressions applied to it,
// ordered from the innermost (applied first) to the outermost. For any other
// kind of expression it returns an empty name and no selectors.
func resolveAssignLHS(
	expr parser.Expr,
) (name string, selectors []parser.Expr) {
	switch term := expr.(type) {
	case *parser.Ident:
		return term.Name, nil
	case *parser.SelectorExpr:
		name, selectors = resolveAssignLHS(term.Expr)
		return name, append(selectors, term.Sel)
	case *parser.IndexExpr:
		name, selectors = resolveAssignLHS(term.Expr)
		return name, append(selectors, term.Index)
	}
	return "", nil
}

// iterateInstructions walks the encoded instructions in b and calls fn with
// the position, opcode and decoded operands of each instruction. The walk
// stops early when fn returns false.
func iterateInstructions(
	b []byte,
	fn func(pos int, opcode parser.Opcode, operands []int) bool,
) {
	pos := 0
	for pos < len(b) {
		opcode := b[pos]
		operands, read := parser.ReadOperands(
			parser.OpcodeOperands[opcode], b[pos+1:])
		if !fn(pos, opcode, operands) {
			return
		}
		// skip the opcode byte and the operand bytes that were read
		pos += 1 + read
	}
}

// tracec prints msg followed by an opening brace to the trace output and
// increases the indent level. It returns c so that it can be used as
// "defer untracec(tracec(c, msg))".
func tracec(c *Compiler, msg string) *Compiler {
	c.printTrace(msg, "{")
	c.indent++
	return c
}

// untracec decreases the indent level and prints the closing brace that
// matches the opening brace printed by tracec.
func untracec(c *Compiler) {
	c.indent--
	c.printTrace("}")
}