/*
 * Copyright 2016 WebAssembly Community Group participants
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//
// Transforms code into SSA form. That ensures each variable has a
// single assignment.
//
// Note that "SSA form" usually means SSA + phis. This pass does not
// create phis, because our AST has no phi instruction. Instead, where
// control flow joins and a value requires more than one input, we emit
// multiple assignments to the same local, with the SSA guarantee that
// one and only one of those assignments will arrive at the uses of that
// "merge local".
// TODO: consider adding a "proper" phi node to the AST, that passes
//       can utilize
//

#include <algorithm>
#include <iterator>
#include <set>

#include "wasm.h"
#include "pass.h"
#include "wasm-builder.h"
#include "support/permutations.h"
#include "ast/literal-utils.h"
#include "ast/local-graph.h"

namespace wasm {

// A set we know is impossible / not in the ast.
// NOTE(review): this object is not referenced anywhere in this file chunk.
// It is a non-static global, so other translation units may refer to it
// (presumably as a sentinel address) — confirm before removing it or
// giving it internal linkage.
SetLocal IMPOSSIBLE_SET;

// Tracks assignments to locals, assuming single-assignment form, i.e.,
// each assignment creates a new variable.

struct SSAify : public Pass {
  bool isFunctionParallel() override { return true; }

  Pass* create() override { return new SSAify; }

  Module* module;
  Function* func;
  std::vector<Expression*> functionPrepends; // things we add to the function prologue

  void runFunction(PassRunner* runner, Module* module_, Function* func_) override {
    module = module_;
    func = func_;
    LocalGraph graph(func, module);
    // create new local indexes, one for each set
    createNewIndexes(graph);
    // we now know the sets for each get
    computeGetsAndPhis(graph);
    // add prepends to function
    addPrepends();
  }

  void createNewIndexes(LocalGraph& graph) {
    for (auto& pair : graph.locations) {
      auto* curr = pair.first;
      if (auto* set = curr->dynCast<SetLocal>()) {
        set->index = addLocal(func->getLocalType(set->index));
      }
    }
  }

  // After we traversed it all, we can compute gets and phis
  void computeGetsAndPhis(LocalGraph& graph) {
    for (auto& iter : graph.getSetses) {
      auto* get = iter.first;
      auto& sets = iter.second;
      if (sets.size() == 0) {
        continue; // unreachable, ignore
      }
      if (sets.size() == 1) {
        // TODO: add tests for this case
        // easy, just one set, use it's index
        auto* set = *sets.begin();
        if (set) {
          get->index = set->index;
        } else {
          // no set, assign param or zero
          if (func->isParam(get->index)) {
            // leave it, it's fine
          } else {
            // zero it out
            (*graph.locations[get]) = LiteralUtils::makeZero(get->type, *module);
          }
        }
        continue;
      }
      // more than 1 set, need a phi: a new local written to at each of the sets
      // if there is already a local with that property, reuse it
      auto gatherIndexes = [](SetLocal* set) {
        std::set<Index> ret;
        while (set) {
          ret.insert(set->index);
          set = set->value->dynCast<SetLocal>();
        }
        return ret;
      };
      auto indexes = gatherIndexes(*sets.begin());
      for (auto* set : sets) {
        if (set == *sets.begin()) continue;
        auto currIndexes = gatherIndexes(set);
        std::vector<Index> intersection;
        std::set_intersection(indexes.begin(), indexes.end(),
                              currIndexes.begin(), currIndexes.end(),
                              std::back_inserter(intersection));
        indexes.clear();
        if (intersection.empty()) break;
        // TODO: or keep sorted vectors?
        for (Index i : intersection) {
          indexes.insert(i);
        }
      }
      if (!indexes.empty()) {
        // we found an index, use it
        get->index = *indexes.begin();
      } else {
        // we need to create a local for this phi'ing
        auto new_ = addLocal(get->type);
        auto old = get->index;
        get->index = new_;
        Builder builder(*module);
        // write to the local in each of our sets
        for (auto* set : sets) {
          if (set) {
            // a set exists, just add a tee of its value
            set->value = builder.makeTeeLocal(
              new_,
              set->value
            );
          } else {
            // this is a param or the zero init value.
            if (func->isParam(old)) {
              // we add a set with the proper
              // param value at the beginning of the function
              auto* set = builder.makeSetLocal(
                new_,
                builder.makeGetLocal(old, func->getLocalType(old))
              );
              functionPrepends.push_back(set);
            } else {
              // this is a zero init, so we don't need to do anything actually
            }
          }
        }
      }
    }
  }

  Index addLocal(WasmType type) {
    return Builder::addVar(func, type);
  }

  void addPrepends() {
    if (functionPrepends.size() > 0) {
      Builder builder(*module);
      auto* block = builder.makeBlock();
      for (auto* pre : functionPrepends) {
        block->list.push_back(pre);
      }
      block->list.push_back(func->body);
      block->finalize(func->body->type);
      func->body = block;
    }
  }
};

// Creates a new SSAify pass instance, allocated with `new`.
// NOTE(review): presumably the caller (e.g. the pass registry) takes
// ownership of the returned pointer — confirm.
Pass* createSSAifyPass() {
  Pass* const pass = new SSAify();
  return pass;
}

} // namespace wasm

