// Copyright (c) 2021 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package appstate

import (
	"bytes"
	"crypto/sha256"
	"encoding/json"
	"fmt"

	"google.golang.org/protobuf/proto"

	waBinary "go.mau.fi/whatsmeow/binary"
	waProto "go.mau.fi/whatsmeow/binary/proto"
	"go.mau.fi/whatsmeow/store"
	"go.mau.fi/whatsmeow/util/cbcutil"
)

// PatchList represents a decoded response to getting app state patches from the WhatsApp servers.
//
// It is produced by ParsePatchList from the <sync><collection> child of the response node.
type PatchList struct {
	// Name is the app state collection name, taken from the collection's "name" attribute.
	Name           WAPatchName
	// HasMorePatches mirrors the optional "has_more_patches" attribute of the collection.
	HasMorePatches bool
	// Patches holds the decoded <patch> children, in response order. When a download
	// function was supplied, patches with external mutations have them downloaded and inlined.
	Patches        []*waProto.SyncdPatch
	// Snapshot is the downloaded full snapshot, or nil when the response had no snapshot node.
	Snapshot       *waProto.SyncdSnapshot
}

// DownloadExternalFunc is a function that can download a blob of external app state patches.
//
// It receives the ExternalBlobReference found either in a collection's snapshot node
// or in a patch's ExternalMutations field, and returns the blob's bytes. Those bytes
// are unmarshaled directly as a SyncdSnapshot or SyncdMutations protobuf, so the
// function presumably has to return already-decrypted data (TODO confirm against
// the implementation passed by the client). When parsing patches, a nil function
// means external mutations are not downloaded.
type DownloadExternalFunc func(*waProto.ExternalBlobReference) ([]byte, error)

// parseSnapshotInternal extracts the optional <snapshot> child of a collection node.
//
// The snapshot node's binary content is an ExternalBlobReference. The actual
// SyncdSnapshot is fetched with downloadExternal and unmarshaled from the result.
// It returns (nil, nil) when the collection has no <snapshot> child with binary
// content.
//
// A snapshot is useless without its downloaded body, so a nil downloadExternal
// is reported as an error rather than causing a nil-function panic. The patch
// parser, by contrast, simply skips external downloads when given nil.
func parseSnapshotInternal(collection *waBinary.Node, downloadExternal DownloadExternalFunc) (*waProto.SyncdSnapshot, error) {
	snapshotNode := collection.GetChildByTag("snapshot")
	rawSnapshot, ok := snapshotNode.Content.([]byte)
	if snapshotNode.Tag != "snapshot" || !ok {
		return nil, nil
	}
	var snapshot waProto.ExternalBlobReference
	if err := proto.Unmarshal(rawSnapshot, &snapshot); err != nil {
		return nil, fmt.Errorf("failed to unmarshal snapshot: %w", err)
	}
	if downloadExternal == nil {
		return nil, fmt.Errorf("failed to download external snapshot: no download function provided")
	}
	rawData, err := downloadExternal(&snapshot)
	if err != nil {
		return nil, fmt.Errorf("failed to download external snapshot: %w", err)
	}
	var downloaded waProto.SyncdSnapshot
	if err = proto.Unmarshal(rawData, &downloaded); err != nil {
		return nil, fmt.Errorf("failed to unmarshal downloaded snapshot: %w", err)
	}
	return &downloaded, nil
}

// parsePatchListInternal decodes every <patch> child of the collection's <patches> node.
//
// Children that are not <patch> nodes, or whose content is not binary, are skipped.
// When a patch references external mutations and downloadExternal is non-nil, the
// referenced blob is downloaded, unmarshaled as SyncdMutations, and used to replace
// the patch's Mutations. A download that yields no mutations is an error. Patch
// numbers in error messages are 1-based positions among all children of <patches>,
// including skipped ones.
func parsePatchListInternal(collection *waBinary.Node, downloadExternal DownloadExternalFunc) ([]*waProto.SyncdPatch, error) {
	patchesNode := collection.GetChildByTag("patches")
	children := patchesNode.GetChildren()
	result := make([]*waProto.SyncdPatch, 0, len(children))
	for idx, child := range children {
		raw, isBytes := child.Content.([]byte)
		if !isBytes || child.Tag != "patch" {
			continue
		}
		patch := &waProto.SyncdPatch{}
		if err := proto.Unmarshal(raw, patch); err != nil {
			return nil, fmt.Errorf("failed to unmarshal patch #%d: %w", idx+1, err)
		}
		if ext := patch.GetExternalMutations(); ext != nil && downloadExternal != nil {
			blob, err := downloadExternal(ext)
			if err != nil {
				return nil, fmt.Errorf("failed to download external mutations: %w", err)
			}
			var external waProto.SyncdMutations
			if err = proto.Unmarshal(blob, &external); err != nil {
				return nil, fmt.Errorf("failed to unmarshal mutation list: %w", err)
			}
			if len(external.GetMutations()) == 0 {
				return nil, fmt.Errorf("didn't get any mutations from download")
			}
			patch.Mutations = external.Mutations
		}
		result = append(result, patch)
	}
	return result, nil
}

// ParsePatchList will decode an XML node containing app state patches, including downloading any external blobs.
//
// It reads the <sync><collection> child of node, parses the collection's optional
// snapshot and its patches (fetching external blobs with downloadExternal), and
// collects the "name" and optional "has_more_patches" attributes. A snapshot or
// patch parsing error is returned with a nil list. An attribute error is returned
// alongside the otherwise populated list.
func ParsePatchList(node *waBinary.Node, downloadExternal DownloadExternalFunc) (*PatchList, error) {
	collection := node.GetChildByTag("sync", "collection")

	// Attribute reads only record problems in the getter; they are reported at the end.
	attrs := collection.AttrGetter()
	collectionName := WAPatchName(attrs.String("name"))
	hasMore := attrs.OptionalBool("has_more_patches")

	snapshot, err := parseSnapshotInternal(&collection, downloadExternal)
	if err != nil {
		return nil, err
	}
	patches, err := parsePatchListInternal(&collection, downloadExternal)
	if err != nil {
		return nil, err
	}

	return &PatchList{
		Name:           collectionName,
		HasMorePatches: hasMore,
		Patches:        patches,
		Snapshot:       snapshot,
	}, attrs.Error()
}

// patchOutput accumulates what decodeMutations produces for one snapshot or patch,
// and is consumed by storeMACs to update the database.
type patchOutput struct {
	// RemovedMACs holds the index MACs of REMOVE mutations, to be deleted from the store.
	RemovedMACs [][]byte
	// AddedMACs holds the index/value MAC pairs of SET mutations, to be inserted into the store.
	AddedMACs   []store.AppStateMutationMAC
	// Mutations holds every decoded mutation, appended after any entries the caller pre-seeded.
	Mutations   []Mutation
}

// decodeMutations decrypts and decodes encrypted app state mutations, appending the results to out.
//
// Each record's value blob is laid out as IV (16 bytes) || ciphertext || value MAC (32 bytes).
// For every mutation it:
//   - looks up the expanded app state keys for the record's key ID,
//   - optionally verifies the value MAC (when validateMACs is set),
//   - decrypts the ciphertext into a SyncActionData,
//   - optionally verifies the record's index MAC against the decrypted index,
//   - JSON-decodes the index into a string slice.
//
// Index MACs of REMOVE mutations are appended to out.RemovedMACs, index/value MAC
// pairs of SET mutations to out.AddedMACs, and every decoded mutation to
// out.Mutations. Processing stops at the first error. Entries already appended for
// earlier mutations are left in out.
//
// Value blobs too short to contain the IV and the value MAC are rejected with an
// error instead of panicking on the slice expressions.
func (proc *Processor) decodeMutations(mutations []*waProto.SyncdMutation, out *patchOutput, validateMACs bool) error {
	const (
		ivLength       = 16
		valueMACLength = 32
	)
	for i, mutation := range mutations {
		keyID := mutation.GetRecord().GetKeyId().GetId()
		keys, err := proc.getAppStateKey(keyID)
		if err != nil {
			return fmt.Errorf("failed to get key %X to decode mutation: %w", keyID, err)
		}
		content := mutation.GetRecord().GetValue().GetBlob()
		// The blob comes from the server: check its size before slicing so a
		// truncated or malformed record yields an error rather than a panic.
		if len(content) < ivLength+valueMACLength {
			return fmt.Errorf("failed to decode mutation #%d: value blob too short (%d bytes)", i+1, len(content))
		}
		content, valueMAC := content[:len(content)-valueMACLength], content[len(content)-valueMACLength:]
		if validateMACs {
			// The expected MAC is computed from the operation, the IV+ciphertext and the key ID.
			expectedValueMAC := generateContentMAC(mutation.GetOperation(), content, keyID, keys.ValueMAC)
			if !bytes.Equal(expectedValueMAC, valueMAC) {
				return fmt.Errorf("failed to verify mutation #%d: %w", i+1, ErrMismatchingContentMAC)
			}
		}
		iv, content := content[:ivLength], content[ivLength:]
		plaintext, err := cbcutil.Decrypt(keys.ValueEncryption, iv, content)
		if err != nil {
			return fmt.Errorf("failed to decrypt mutation #%d: %w", i+1, err)
		}
		var syncAction waProto.SyncActionData
		err = proto.Unmarshal(plaintext, &syncAction)
		if err != nil {
			return fmt.Errorf("failed to unmarshal mutation #%d: %w", i+1, err)
		}
		indexMAC := mutation.GetRecord().GetIndex().GetBlob()
		if validateMACs {
			// The index MAC is an HMAC-SHA256 of the decrypted index under the index key.
			expectedIndexMAC := concatAndHMAC(sha256.New, keys.Index, syncAction.Index)
			if !bytes.Equal(expectedIndexMAC, indexMAC) {
				return fmt.Errorf("failed to verify mutation #%d: %w", i+1, ErrMismatchingIndexMAC)
			}
		}
		var index []string
		err = json.Unmarshal(syncAction.GetIndex(), &index)
		if err != nil {
			return fmt.Errorf("failed to unmarshal index of mutation #%d: %w", i+1, err)
		}
		switch mutation.GetOperation() {
		case waProto.SyncdMutation_REMOVE:
			out.RemovedMACs = append(out.RemovedMACs, indexMAC)
		case waProto.SyncdMutation_SET:
			out.AddedMACs = append(out.AddedMACs, store.AppStateMutationMAC{
				IndexMAC: indexMAC,
				ValueMAC: valueMAC,
			})
		}
		out.Mutations = append(out.Mutations, Mutation{
			Operation: mutation.GetOperation(),
			Action:    syncAction.GetValue(),
			Index:     index,
			IndexMAC:  indexMAC,
			ValueMAC:  valueMAC,
		})
	}
	return nil
}

// storeMACs persists the outcome of applying a snapshot or patch for collection name.
//
// It performs three writes: the new version and hash of currentState; the deletion
// of the index MACs in out.RemovedMACs; and the insertion of the MAC pairs in
// out.AddedMACs under the new version. It is best-effort: every write is attempted
// even if an earlier one failed, and failures are only logged, never returned.
func (proc *Processor) storeMACs(name WAPatchName, currentState HashState, out *patchOutput) {
	collection := string(name)
	if err := proc.Store.AppState.PutAppStateVersion(collection, currentState.Version, currentState.Hash); err != nil {
		proc.Log.Errorf("Failed to update app state version in the database: %v", err)
	}
	if err := proc.Store.AppState.DeleteAppStateMutationMACs(collection, out.RemovedMACs); err != nil {
		proc.Log.Errorf("Failed to remove deleted mutation MACs from the database: %v", err)
	}
	if err := proc.Store.AppState.PutAppStateMutationMACs(collection, currentState.Version, out.AddedMACs); err != nil {
		proc.Log.Errorf("Failed to insert added mutation MACs to the database: %v", err)
	}
}

// validateSnapshotMAC checks the state hash in currentState against the snapshot MAC sent by the server.
//
// It fetches the expanded app state keys for keyID and recomputes the snapshot MAC
// of currentState for collection name. The expanded keys are returned so callers
// can verify further MACs with them. They are returned even when the MAC does not
// match. A mismatch is reported as an error wrapping ErrMismatchingLTHash.
func (proc *Processor) validateSnapshotMAC(name WAPatchName, currentState HashState, keyID, expectedSnapshotMAC []byte) (keys ExpandedAppStateKeys, err error) {
	keys, err = proc.getAppStateKey(keyID)
	if err != nil {
		return keys, fmt.Errorf("failed to get key %X to verify patch v%d MACs: %w", keyID, currentState.Version, err)
	}
	if actual := currentState.generateSnapshotMAC(name, keys.SnapshotMAC); !bytes.Equal(actual, expectedSnapshotMAC) {
		return keys, fmt.Errorf("failed to verify patch v%d: %w", currentState.Version, ErrMismatchingLTHash)
	}
	return keys, nil
}

// decodeSnapshot applies a full app state snapshot on top of initialState.
//
// Every snapshot record is wrapped as a SET mutation. The state hash is updated with
// those mutations; the previous-value lookup callback always returns nil, so no
// earlier value is looked up. The snapshot MAC is optionally verified, the records
// are decrypted and decoded, and the resulting version, hash and MACs are persisted
// via storeMACs. Hash-update warnings are logged, not returned.
//
// On success it returns newMutationsInput extended with the decoded mutations, plus
// the new hash state. On error the returned mutation slice is nil, and the returned
// state may already carry the snapshot's version and a partially or fully updated hash.
func (proc *Processor) decodeSnapshot(name WAPatchName, ss *waProto.SyncdSnapshot, initialState HashState, validateMACs bool, newMutationsInput []Mutation) (newMutations []Mutation, currentState HashState, err error) {
	state := initialState
	state.Version = ss.GetVersion().GetVersion()

	records := ss.GetRecords()
	asSets := make([]*waProto.SyncdMutation, len(records))
	for idx := range records {
		asSets[idx] = &waProto.SyncdMutation{
			Operation: waProto.SyncdMutation_SET.Enum(),
			Record:    records[idx],
		}
	}

	warnings, hashErr := state.updateHash(asSets, func([]byte, int) ([]byte, error) {
		return nil, nil
	})
	if len(warnings) > 0 {
		proc.Log.Warnf("Warnings while updating hash for %s: %+v", name, warnings)
	}
	if hashErr != nil {
		return nil, state, fmt.Errorf("failed to update state hash: %w", hashErr)
	}

	if validateMACs {
		if _, macErr := proc.validateSnapshotMAC(name, state, ss.GetKeyId().GetId(), ss.GetMac()); macErr != nil {
			return nil, state, macErr
		}
	}

	out := patchOutput{Mutations: newMutationsInput}
	if decodeErr := proc.decodeMutations(asSets, &out, validateMACs); decodeErr != nil {
		return nil, state, fmt.Errorf("failed to decode snapshot of v%d: %w", state.Version, decodeErr)
	}
	proc.storeMACs(name, state, &out)
	return out.Mutations, state, nil
}

// DecodePatches will decode all the patches in a PatchList into a list of app state mutations.
//
// Starting from initialState, it first applies list.Snapshot (if present) and then
// every patch in order. For each patch it:
//   - sets the state version to the patch's version,
//   - updates the state hash with the patch's mutations,
//   - optionally verifies the snapshot MAC and the patch MAC (when validateMACs is set),
//   - decrypts and decodes the mutations,
//   - persists the new version and MACs via storeMACs.
//
// It returns the decoded mutations of the snapshot and all patches, in order,
// together with the final hash state. Processing stops at the first error. Versions
// and MACs of patches applied before the failing one have already been persisted.
// Hash-update warnings are logged, not returned.
func (proc *Processor) DecodePatches(list *PatchList, initialState HashState, validateMACs bool) (newMutations []Mutation, currentState HashState, err error) {
	// valueMACLength is the size of the MAC suffix of every record value blob.
	const valueMACLength = 32

	currentState = initialState
	var expectedLength int
	if list.Snapshot != nil {
		expectedLength = len(list.Snapshot.GetRecords())
	}
	for _, patch := range list.Patches {
		expectedLength += len(patch.GetMutations())
	}
	newMutations = make([]Mutation, 0, expectedLength)

	if list.Snapshot != nil {
		newMutations, currentState, err = proc.decodeSnapshot(list.Name, list.Snapshot, currentState, validateMACs, newMutations)
		if err != nil {
			return
		}
	}

	for _, patch := range list.Patches {
		version := patch.GetVersion().GetVersion()
		currentState.Version = version
		mutations := patch.GetMutations()
		var warn []error
		warn, err = currentState.updateHash(mutations, func(indexMAC []byte, maxIndex int) ([]byte, error) {
			// Look for the closest earlier mutation in this patch that touches the same index.
			// NOTE(review): this matches REMOVE mutations too, not only SETs; confirm that
			// returning a REMOVE's value MAC here matches the server's hash semantics.
			for i := maxIndex - 1; i >= 0; i-- {
				if bytes.Equal(mutations[i].GetRecord().GetIndex().GetBlob(), indexMAC) {
					value := mutations[i].GetRecord().GetValue().GetBlob()
					// The blob comes from the server; guard the slice below against short data.
					if len(value) < valueMACLength {
						return nil, fmt.Errorf("value blob of mutation #%d in patch v%d is too short (%d bytes)", i+1, version, len(value))
					}
					return value[len(value)-valueMACLength:], nil
				}
			}
			// Previous value not found in current patch, look in the database
			return proc.Store.AppState.GetAppStateMutationMAC(string(list.Name), indexMAC)
		})
		if len(warn) > 0 {
			proc.Log.Warnf("Warnings while updating hash for %s: %+v", list.Name, warn)
		}
		if err != nil {
			err = fmt.Errorf("failed to update state hash: %w", err)
			return
		}

		if validateMACs {
			var keys ExpandedAppStateKeys
			keys, err = proc.validateSnapshotMAC(list.Name, currentState, patch.GetKeyId().GetId(), patch.GetSnapshotMac())
			if err != nil {
				return
			}
			patchMAC := generatePatchMAC(patch, list.Name, keys.PatchMAC)
			if !bytes.Equal(patchMAC, patch.GetPatchMac()) {
				err = fmt.Errorf("failed to verify patch v%d: %w", version, ErrMismatchingPatchMAC)
				return
			}
		}

		var out patchOutput
		out.Mutations = newMutations
		err = proc.decodeMutations(mutations, &out, validateMACs)
		if err != nil {
			// Add the patch version, matching how decodeSnapshot reports its failures.
			err = fmt.Errorf("failed to decode patch v%d: %w", version, err)
			return
		}
		proc.storeMACs(list.Name, currentState, &out)
		newMutations = out.Mutations
	}
	return
}