import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.Set;
import java.util.TreeSet;

/* Created by WangLei on 20-1-10. */
public class TimeUtils {

    private static final Logger LOGGER = LoggerFactory.getLogger(TimeUtils.class);

    public static final String DATE_FORMAT = "yyyyMMdd";
    public static final String TIME_FORMAT = "yyyyMMdd HH:mm:ss";
    public static final String HOUR_TIME_FORMAT = "yyyyMMdd HH";

    public static final long TIME_DAY_MILLISECOND = 86400000;

     * timestamp -> ymd
     * @param timestamp
     * @return
    public static String timestamp2Ymd(long timestamp) {
        String format = "yyyyMMdd";
        return timestamp2Ymd(timestamp, format);

    public static String timestamp2Ymd(long timestamp, String format) {
        SimpleDateFormat sdf;
        try {
            if(String.valueOf(timestamp).length() == 10) {
                timestamp *= 1000;
            sdf = new SimpleDateFormat(format);
            return sdf.format(new Date(timestamp));
        } catch(Exception ex) {
            sdf = new SimpleDateFormat(DATE_FORMAT);
            try {
                return sdf.format(new Date(timestamp));
            } catch (Exception e){}
        return null;

    public static String timestamp2Hour(long timestamp) {
        String time = timestamp2Ymd(timestamp, TIME_FORMAT);
        return time.substring(9, 11);

     * ymd -> Date
     * @param ymd
     * @return
    public static Date ymd2Date(String ymd) {
        return ymd2Date(ymd, "yyyyMMdd");

    public static Date ymd2Date(String ymd, String format) {
        try {
            SimpleDateFormat sdf = new SimpleDateFormat(format);
            return sdf.parse(ymd);
        } catch(ParseException ex) {
            LOGGER.error("parse ymd to timestamp error!", ex);
        } catch (Exception ex) {
            LOGGER.error("there is some problem when transfer ymd2Date!", ex);
        return null;

     * ymd -> timestamp
     * @param ymd
     * @return
    public static long ymd2timestamp(String ymd) {
        return ymd2Date(ymd).getTime();

    public static String genLastDayStr() {
        return timestamp2Ymd(System.currentTimeMillis() + TIME_DAY_MILLISECOND * (-1));

     * get the datestr before or after the given datestr
     * attention transfer the num from int to long
     * @param ymd
     * @param num
     * @return
    public static String genDateAfterInterval(String ymd, int num) {
        long timestamp = ymd2timestamp(ymd);
        long resTimeStamp = timestamp + TIME_DAY_MILLISECOND * Long.valueOf(num);
        return timestamp2Ymd(resTimeStamp);

    public static String genLastDayStr(String ymd) {
        return genDateAfterInterval(ymd, -1);

    public static Set<String> genYmdSet(long beginTs, long endTs) {
        TreeSet ymdSet = new TreeSet();
        for(long ts = beginTs; ts <= endTs; ts += 86400000L) {
        return ymdSet;

    public static Set<String> genYmdSet(String beginYmd, String endYmd) {
        long beginTs = ymd2timestamp(beginYmd);
        long endTs = ymd2timestamp(endYmd);
        return genYmdSet(beginTs, endTs);

     * end between begin days
     * if begin or end is not number format or end < begin, return Integer.MIN_VALUE
     * @param begin
     * @param end
     * @return
    public static int getIntervalBetweenTwoDays(String begin, String end) {
        try {
            int begintmp = Integer.valueOf(begin), endtmp = Integer.valueOf(end);
            if(begintmp > endtmp) {
                LOGGER.error("we need end no smaller than end!");
                return Integer.MIN_VALUE;
            Date d1 = ymd2Date(begin);
            Date d2 = ymd2Date(end);
            Long mils = (d2.getTime() - d1.getTime()) / TIME_DAY_MILLISECOND;
            return mils.intValue();
        } catch (NumberFormatException numformatex) {
            return Integer.MIN_VALUE;


/* Created by WangLei on 20-1-13. */
/**
  * Layout of the date directory that `HDFSUtils.latestMidPath` appends to a base path:
  *
  *  - `YMD`   : `/date=yyyyMMdd`
  *  - `Y_M_D` : `/year=yyyy/month=MM/day=dd`
  *  - `YMD2`  : `/yyyyMMdd`
  */
object DateSpec extends Enumeration {
    type DateSpec = Value

    val YMD, Y_M_D, YMD2 = Value
}


import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext
import org.joda.time.DateTime

/* Created by WangLei on 20-1-10. */
/**
  * Helpers for deleting paths, testing path existence and locating the most recent
  * date-named directory through the Hadoop `FileSystem` API.
  */
object HDFSUtils {

    /** Default Hadoop configuration, used when a caller does not supply one. */
    val conf: Configuration = new Configuration()

    /**
      * Recursively deletes `path` using the Hadoop configuration of `sc`.
      *
      * @return the result of `FileSystem.delete`
      */
    def delete(sc: SparkContext, path: String): Boolean =
        FileSystem.get(sc.hadoopConfiguration).delete(new Path(path), true)

    /** Tells whether `path` exists, using the Hadoop configuration of `sc`. */
    def isExist(sc: SparkContext, path: String): Boolean =
        FileSystem.get(sc.hadoopConfiguration).exists(new Path(path))

    /**
      * Tells whether `FileName` exists. Best effort: if the lookup throws, the stack
      * trace is printed and `false` is returned.
      *
      * @param conf     Hadoop configuration; defaults to [[HDFSUtils.conf]]
      * @param FileName path to test
      */
    def checkFileExist(conf: Configuration = conf, FileName: String): Boolean = {
        try {
            FileSystem.get(conf).exists(new Path(FileName))
        } catch {
            case e: Exception =>
                e.printStackTrace()
                false
        }
    }

    /**
      * Finds the newest `basePath/date=yyyyMMdd` directory containing a `_SUCCESS`
      * marker, looking at today and the 7 days before it.
      */
    def latestMidPath(conf: Configuration, basePath: String): Option[String] =
        latestMidPath(conf, basePath, new DateTime(), 7)

    /**
      * Finds the newest existing `basePath/yyyyMMdd` directory, looking at `ymd` and
      * the 7 days before it. No `_SUCCESS` marker is required.
      *
      * @param ymd day string in `yyyyMMdd` form, parsed by `TimeUtils.ymd2timestamp`
      */
    def latestMidPath(conf: Configuration, basePath: String, ymd: String): Option[String] = {
        val timestamp = TimeUtils.ymd2timestamp(ymd)
        latestMidPath(conf, basePath, new DateTime(timestamp), 7,
            with_success_file = false, dateSpec = DateSpec.YMD2)
    }

    /**
      * Walks backwards from `date`, one day at a time, and returns the first dated
      * directory that exists.
      *
      * @param conf              Hadoop configuration; defaults to [[HDFSUtils.conf]]
      * @param basePath          directory under which the dated directories live
      * @param date              first day to try
      * @param limit             how many additional earlier days to try (`limit + 1` candidates)
      * @param with_success_file when true, a candidate counts only if it contains `_SUCCESS`
      * @param dateSpec          how the date is rendered into the directory name
      * @return the matching directory path (without `/_SUCCESS`), or `None`
      */
    def latestMidPath(conf: Configuration = conf, basePath: String, date: DateTime, limit: Int,
                      with_success_file: Boolean = true,
                      dateSpec: DateSpec.Value = DateSpec.YMD): Option[String] = {
        // Lazy iterator: candidates are probed newest-first and probing stops at the first hit.
        (0 to limit).iterator
            .map { offset =>
                val day = date.minusDays(offset)
                dateSpec match {
                    case DateSpec.YMD =>
                        basePath + "/date=%04d%02d%02d".format(day.getYear, day.getMonthOfYear, day.getDayOfMonth)
                    case DateSpec.Y_M_D =>
                        basePath + "/year=%04d/month=%02d/day=%02d".format(day.getYear, day.getMonthOfYear, day.getDayOfMonth)
                    case DateSpec.YMD2 =>
                        basePath + "/%04d%02d%02d".format(day.getYear, day.getMonthOfYear, day.getDayOfMonth)
                }
            }
            .find(path => checkFileExist(conf, if (with_success_file) path + "/_SUCCESS" else path))
    }
}



import org.apache.spark.SparkConf
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.SparkSession
import org.slf4j.LoggerFactory

import scala.collection.JavaConversions._

/**
  * Spark job that builds dense user/item indexes from click logs and trains an ALS
  * model on the aggregated click counts.
  *
  * Usage: `AlsTraining <ymd> <index|model>`
  *
  * NOTE(review): in the reviewed source the join steps of `trainmodel`, the ALS and
  * evaluator column settings, and the final `saveAsTextFile` calls of the "model"
  * branch were missing. They are reconstructed here from the tuple shapes the
  * surviving code destructures - confirm against the original job before relying on it.
  */
object AlsTraining {
    val logger = LoggerFactory.getLogger(this.getClass)
    val separator = "\t"

    /** Reads one click file: tab-separated `userid, itemid, clicknum`, kept as strings. */
    private def readClicks(spark: SparkSession, path: String) =
        spark.sparkContext.textFile(path).map { line =>
            val fields = line.split("\t")
            (fields(0), fields(1), fields(2))
        }

    /** Assigns each distinct id a 1-based index following the sorted order of the ids. */
    private def buildIndex(ids: org.apache.spark.rdd.RDD[String]) =
        ids.distinct().sortBy(identity).zipWithIndex().map { case (id, idx) => (id, idx + 1) }

    /**
      * Loads the click records of `ymd` plus those of each of the 29 preceding days
      * whose input path exists.
      *
      * @return RDD of `(userid, itemid, clicknum)` with `clicknum` still a string
      */
    def genUserItemRdd(spark: SparkSession, ymd: String) = {
        import scala.collection.JavaConverters._

        val baseinput = PathUtils.user_item_click_path
        val (yesterday, daybegin) = (TimeUtils.genLastDayStr(ymd), TimeUtils.genDateAfterInterval(ymd, -29))
        val history = TimeUtils.genYmdSet(daybegin, yesterday).asScala.toSeq
            .map(day => baseinput + day)
            .filter(path => HDFSUtils.isExist(spark.sparkContext, path))
            .map(path => readClicks(spark, path))
        // The path for ymd itself is read without an existence check, as before.
        history.foldLeft(readClicks(spark, baseinput + ymd))(_ union _)
    }

    /**
      * Builds the user and item index tables for the click window ending at `ymd`.
      *
      * @return `(userindex, itemindex)`, each an RDD of `(id, 1-based index)`
      */
    def genUserItemIndex(spark: SparkSession, ymd: String) = {
        val rdd = genUserItemRdd(spark, ymd)
        (buildIndex(rdd.map(_._1)), buildIndex(rdd.map(_._2)))
    }

    case class Rating(userid: Int, itemid: Int, rating: Float)

    /** Joins factor rows `(index, features)` back to their ids; lines are `index\tid\tf1,f2,...`. */
    private def labelFactors(factors: org.apache.spark.sql.DataFrame,
                             index2id: org.apache.spark.rdd.RDD[(Long, String)]) =
        factors.rdd
            .map(row => (row.getInt(0).toLong, row.getSeq[Float](1).mkString(",")))
            .join(index2id)
            .map { case (index, (factor, id)) => (index, id, factor) }
            .sortBy(_._1)
            .map(x => "%s\t%s\t%s".format(x._1, x._2, x._3))

    /**
      * Trains ALS (rank 128, 8 iterations, regParam 0.01) on an 80% random split of
      * the summed click counts, logs the RMSE on the remaining 20%, and returns the
      * learned factors as text lines.
      *
      * @return `(userfactors, itemfactors)` as RDDs of `index\tid\tcomma-separated factors`
      */
    def trainmodel(spark: SparkSession, ymd: String) = {
        import spark.implicits._

        val rdd = genUserItemRdd(spark, ymd)
        val userindexrdd = buildIndex(rdd.map(_._1))
        val itemindexrdd = buildIndex(rdd.map(_._2))

        // Sum clicks per (user, item), then swap both ids for their integer indexes.
        val data = rdd
            .map { case (userid, itemid, count) => ((userid, itemid), count.toInt) }
            .reduceByKey(_ + _)
            .map { case ((userid, itemid), count) => (userid, (itemid, count)) }
            .join(userindexrdd)
            .map { case (_, ((itemid, count), userindex)) => (itemid, (userindex, count)) }
            .join(itemindexrdd)
            .map { case (_, ((userindex, count), itemindex)) =>
                Rating(userindex.toInt, itemindex.toInt, count.toFloat)
            }
            .toDF()

        val Array(training, test) = data.randomSplit(Array(0.8, 0.2))
        val als = new ALS().setRank(128).setMaxIter(8).setRegParam(0.01)
            .setUserCol("userid").setItemCol("itemid").setRatingCol("rating")
        val model = als.fit(training)

        // NOTE(review): users/items present only in the test split get NaN predictions,
        // which makes the RMSE NaN; consider a cold-start strategy if Spark supports it.
        val predictions = model.transform(test)
        val evaluator = new RegressionEvaluator()
            .setMetricName("rmse").setLabelCol("rating").setPredictionCol("prediction")
        val rmse = evaluator.evaluate(predictions)
        // Boxed explicitly: slf4j's parameterised overloads take Object, not a primitive.
        logger.error("root-mean-square error is: {}", Double.box(rmse))

        val userfactors = labelFactors(model.userFactors, userindexrdd.map(_.swap))
        val itemfactors = labelFactors(model.itemFactors, itemindexrdd.map(_.swap))
        (userfactors, itemfactors)
    }

    def main(args: Array[String]): Unit = {
        require(args.length >= 2, "usage: AlsTraining <ymd> <index|model>")
        val (ymd, operation) = (args(0), args(1))
        val sparkConf = new SparkConf()
        sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
        sparkConf.setAppName("user-item-als-training" + ymd)
        val spark = SparkSession.builder().config(sparkConf).getOrCreate()
        operation match {
            case "index" =>
                val useroutput = PathUtils.user_index_path + ymd
                val itemoutput = PathUtils.item_index_path + ymd
                val (userindex, itemindex) = genUserItemIndex(spark, ymd)
                // Write "index\tid". The previous code formatted x._2 twice, so the id was
                // never written. NOTE(review): column order follows the factor files - confirm.
                userindex.repartition(1).sortBy(_._2).map(x => "%s\t%s".format(x._2, x._1)).saveAsTextFile(useroutput)
                itemindex.repartition(1).sortBy(_._2).map(x => "%s\t%s".format(x._2, x._1)).saveAsTextFile(itemoutput)
            case "model" =>
                val (userfactors, itemfactors) = trainmodel(spark, ymd)
                val user_embedding_path = PathUtils.user_factor_path + ymd
                val item_embedding_path = PathUtils.item_factor_path + ymd
                HDFSUtils.delete(spark.sparkContext, user_embedding_path)
                HDFSUtils.delete(spark.sparkContext, item_embedding_path)
                userfactors.saveAsTextFile(user_embedding_path)
                itemfactors.saveAsTextFile(item_embedding_path)
            case other =>
                throw new IllegalArgumentException("unknown operation: " + other)
        }
    }
}



