Python Spark Job TypeError: Casting to unit-less dtype 'datetime64' is not supported. Pass e.g. 'datetime64[ns]' instead


I have a Spark job that reads multiple txt files from a GCS bucket, makes some transformations, adds a couple of columns, and then exports the result to a new GCS bucket. The ingestion from the first bucket works fine, as I can print the dataframe and see the column names and values. But once I try to export it to the other GCS bucket, I get the following error:

TypeError: Casting to unit-less dtype 'datetime64' is not supported. Pass e.g. 'datetime64[ns]' instead

I don't see where this is coming from. Can someone help me find the root cause of this? Thanks.
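For reference, the error text matches what pandas 2.x raises when something calls astype() with a unit-less 'datetime64' dtype. A tiny pandas-only sketch (no Spark involved, just to show where the message comes from, and assuming pandas 2.x) looks like this:

import pandas as pd

s = pd.Series(["2024-01-01", "2024-01-02"])

try:
    s.astype("datetime64")            # unit-less dtype: pandas 2.x raises the TypeError above
except TypeError as err:
    print(err)

print(s.astype("datetime64[ns]"))     # an explicit unit works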

Here's how the data looks after printing the dataframe (I add on the last four columns):

(screenshot of the printed dataframe)

# declaring the packages
import json
import logging
import multiprocessing as mp
import re
import subprocess
import sys
import time
from configparser import ConfigParser
from datetime import date, datetime, timedelta
from multiprocessing.pool import Pool, ThreadPool
from typing import Any, Dict, List

import pandas as pd
from dateutil import parser
from google.cloud import bigquery, secretmanager, storage
from pyarrow import pandas_compat
from pyspark.conf import SparkConf
from pyspark.sql import Row, SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import col, concat, lit, regexp_replace, udf
from pyspark.sql.types import (
    BooleanType,
    DateType,
    DecimalType,
    IntegerType,
    StringType,
    StructField,
    StructType,
    TimestampType,
)


def get_logger():
    logger = logging.getLogger(__file__)
    if not logger.handlers:
        logger.setLevel(logging.DEBUG)
        file_handler = logging.StreamHandler()
        formatter = logging.Formatter(
            "%(asctime)s : %(levelname)s : %(name)s : %(message)s"
        )
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    return logger


log = get_logger()

# schema for rou_billing table
rou_billing_schema = (
    StructType()
    .add("Business_line", StringType(), True)
    .add("Site", StringType(), True)
    .add("Customer", StringType(), True)
    .add("Bill_code", StringType(), True)
    .add("Date", DateType(), True)
    .add("Quantity", DecimalType(), True)
    .add("Unite_price", DecimalType(), True)
    .add("Invoice_no", IntegerType(), True)
    .add("partition_date", DateType(), True)
    .add("ingestion_time", TimestampType(), True)
    .add("file_name", StringType(), True)

)

# mapping table to input schema
schema_mapping = {"rou_billing": rou_billing_schema}


def main(argv):
    # Bucket where the file was uploaded
    source_location = argv[0]
    # Name of the feed file
    file_name = argv[1]
    # Name of the table, used to look up the schema for the file
    table_name = argv[2]
    # Where to save the results of the transformation
    destination_location = argv[3]

    global spark
    spark = (
        SparkSession.builder.master("yarn")
        .appName("parquet_transformation")
        .getOrCreate()
    )
    spark.conf.set("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "CORRECTED")
    spark.conf.set("spark.sql.inMemoryColumnarStorage.compressed", True)
    spark.conf.set("spark.sql.inMemoryColumnarStorage.batchSize", 10000)
    sc = spark.sparkContext
    sc.setLogLevel("WARN")

    txt_files = sc.wholeTextFiles(source_location).keys().collect()
    print(txt_files)  # Prints a list of files under the bucket

    # get schema from the schema_mapping dict based on the table name
    input_schema = schema_mapping.get(table_name)

    if input_schema:
        print(f"SCHEMA: {input_schema}")
    else:
        sys.exit("Invalid File. No schema found for the input file.")

    # Read each file from the bucket
    for txt_file in txt_files:
        try:
            df_file = (
                spark.read.format("com.databricks.spark.csv")
                .option("delimiter", "\t")
                .option("header", "true")
                .load(txt_file)
            )
        except Exception as e:
            log.error(f"Error opening file {txt_file}. Reason: {e}")
            raise e
        
        file_name = (txt_file.rsplit("/", 1)[-1]).rsplit(".")[0]

        # General column header cleaning.

        df_file = df_file.toDF(*[c.lower() for c in df_file.columns])
        df_file = df_file.toDF(*[c.replace('"', '') for c in df_file.columns])
        df_file = df_file.toDF(*[c.replace(' ', '_') for c in df_file.columns])
        df_file = df_file.toDF(*[c.lstrip('_') for c in df_file.columns])
        df_file = df_file.toDF(*[c.rstrip('_') for c in df_file.columns])
        df_file = df_file.toDF(*[c.replace('-', '_') for c in df_file.columns])

        # adding columns

        df_file = df_file.withColumn('partition_date', F.current_date())
        df_file = df_file.withColumn("ingestion_time", F.current_timestamp())
        df_file = df_file.withColumn('file_name', lit(file_name))

        df_file.printSchema()

        df_total = df_file

        print(f"Showing File {file_name}")
        df_total.show()

        des_loc = destination_location + "/transform/" + file_name
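        # the TypeError above appears when this export runs: toPandas()
        # converts the Spark DataFrame (including the partition_date and
        # ingestion_time columns added above) before to_csv()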
        
        df_total.toPandas().to_csv(des_loc)

    print(f"Data has successfully been loaded into {destination_location}")

    spark.stop()

if __name__ == "__main__":
    main(sys.argv[1:])
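The only date/timestamp values in the frame should be the partition_date and ingestion_time columns added with current_date()/current_timestamp(), so my best guess is that the toPandas() conversion of those columns is where the cast happens. Below is a minimal isolation sketch I could run separately (hypothetical, assuming the driver has pandas 2.x installed, where casts to unit-less datetime64 were removed):

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").appName("toPandas_check").getOrCreate()

# one-row DataFrame with the same column types the job adds
df = (
    spark.range(1)
    .withColumn("partition_date", F.current_date())        # DateType
    .withColumn("ingestion_time", F.current_timestamp())   # TimestampType
)

print(pd.__version__)   # pandas 2.x no longer accepts casts to unit-less 'datetime64'
print(df.toPandas())    # if this alone raises the same TypeError, toPandas() is the trigger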

There are 0 answers