-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path main.py
More file actions
212 lines (150 loc) · 6.83 KB
/
main.py
File metadata and controls
212 lines (150 loc) · 6.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import argparse
import io
import requests
import os
import shutil
import pandas as pd
import pyarrow.parquet as pq
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import StructType, StructField, StringType, LongType, DoubleType, TimestampType
# Expected column names and Spark types of the TLC yellow-taxi Parquet files.
# StructField's default nullable=True applies, so every column may hold nulls.
schema = StructType([
    StructField(field_name, field_type)
    for field_name, field_type in (
        ("VendorID", LongType()),
        ("tpep_pickup_datetime", TimestampType()),
        ("tpep_dropoff_datetime", TimestampType()),
        ("passenger_count", DoubleType()),
        ("trip_distance", DoubleType()),
        ("RatecodeID", DoubleType()),
        ("store_and_fwd_flag", StringType()),
        ("PULocationID", LongType()),
        ("DOLocationID", LongType()),
        ("payment_type", LongType()),
        ("fare_amount", DoubleType()),
        ("extra", DoubleType()),
        ("mta_tax", DoubleType()),
        ("tip_amount", DoubleType()),
        ("tolls_amount", DoubleType()),
        ("improvement_surcharge", DoubleType()),
        ("total_amount", DoubleType()),
        ("congestion_surcharge", DoubleType()),
        ("airport_fee", DoubleType()),
    )
])
def create_spark_session(session):
    """Return a SparkSession whose application name is *session*.

    Uses ``SparkSession.builder.getOrCreate()``, so an already-running session
    may be returned instead of a new one. Spark's console logging is lowered to
    ERROR, and a confirmation line is printed.
    """
    session_builder = SparkSession.builder.appName(session)
    spark_session = session_builder.getOrCreate()
    # Silence INFO/WARN chatter from Spark so the script's own prints stay readable.
    spark_session.sparkContext.setLogLevel("ERROR")
    print(f"Spark session {session} created")
    return spark_session
def remove_files_from_path(path):
    """Delete every file (and symlink) directly inside the directory *path*.

    Best-effort cleanup: sub-directories are left untouched, and a failure on
    one entry is printed and skipped instead of abandoning the remaining files.
    A missing or unreadable *path* is reported with a print; nothing is raised.
    """
    try:
        entries = os.listdir(path)
    except OSError as e:
        print(f"Error: {str(e)}")
        return
    print(f'Removing files from {path}')
    for file_name in entries:
        file_path = os.path.join(path, file_name)
        # os.remove cannot delete directories; skip them rather than erroring out
        # and leaving later files behind. Symlinks (even to directories) are removed
        # as links, never followed.
        if not (os.path.isfile(file_path) or os.path.islink(file_path)):
            continue
        try:
            os.remove(file_path)
        except OSError as e:
            # Keep going: one locked/permission-denied file should not block the rest.
            print(f"Error: {str(e)}")
def download_parquet_files(start_date, end_date, download_path,
                           base_url="https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_",
                           timeout=60):
    """Download one TLC yellow-taxi Parquet file per month into *download_path*.

    The folder is created if needed and emptied of existing files first. Each
    file is saved byte-for-byte as served.

    Args:
        start_date: first month to fetch, e.g. "2023-01" (any string pandas can
            parse as a date; only its year and month are used).
        end_date: last month to fetch, inclusive.
        download_path: local folder that receives the ``.parquet`` files.
        base_url: URL prefix; "YYYY-MM.parquet" is appended for each month.
        timeout: seconds to wait on each HTTP request before giving up.

    Errors are printed, never raised. A month that fails to download is
    reported and skipped, and the remaining months are still attempted.
    """
    try:
        os.makedirs(download_path, exist_ok=True)
        remove_files_from_path(download_path)
        # One period per calendar month. pd.date_range's default *daily*
        # frequency would request the same monthly file ~30 times.
        months = pd.period_range(start_date, end_date, freq="M")
    except (OSError, ValueError) as e:
        print(f"Error: {str(e)}")
        return

    for month in months:
        url = f"{base_url}{month.strftime('%Y-%m')}.parquet"
        try:
            response = requests.get(url, timeout=timeout)
        except requests.RequestException as e:
            print(f"Failed to download data from {url}: {e}")
            continue
        if response.status_code != 200:
            print(f"Failed to download data from {url}")
            continue
        local_file_path = os.path.join(download_path, os.path.basename(url))
        try:
            with open(local_file_path, "wb") as fh:
                fh.write(response.content)
        except OSError as e:
            print(f"Error: {str(e)}")
            continue
        print(f"Downloading {os.path.basename(url)}")
def _align_to_schema(sdf, target_schema):
    """Project *sdf* onto *target_schema*: same column names, order and types."""
    # Column names are matched case-insensitively, because capitalisation can
    # differ between monthly files. Columns absent from a file become typed
    # nulls, so every month ends up with an identical layout and can be unioned.
    # Columns not in target_schema are dropped.
    actual_names = {name.lower(): name for name in sdf.columns}
    projected = []
    for field in target_schema.fields:
        source_name = actual_names.get(field.name.lower())
        source = col(source_name) if source_name is not None else lit(None)
        projected.append(source.cast(field.dataType).alias(field.name))
    return sdf.select(projected)


def process_parquet_data(folder_path, spark_session=None, top_fraction=0.10):
    """Combine every ``.parquet`` file in *folder_path* and keep the longest trips.

    Each file is aligned to the module-level ``schema`` (columns cast to its
    types, missing columns filled with nulls). The files are then unioned by
    column name, and the rows are sorted by ``trip_distance`` descending. Only
    the top ``int(total_rows * top_fraction)`` rows are kept.

    Args:
        folder_path: directory containing the downloaded Parquet files.
        spark_session: SparkSession to read with; None means
            ``SparkSession.builder.getOrCreate()`` (the active session, if any).
        top_fraction: share of rows to keep, 0.10 (10%) by default.

    Returns:
        The resulting Spark DataFrame, or None if no Parquet file was found or
        processing failed (the reason is printed).
    """
    from functools import reduce

    try:
        session = spark_session if spark_session is not None else SparkSession.builder.getOrCreate()
        # Sorted so the union order (and thus the output) does not depend on
        # the filesystem's directory-listing order.
        files = sorted(f for f in os.listdir(folder_path) if f.endswith(".parquet"))
        if not files:
            print(f"Error processing parquet data: no .parquet files found in {folder_path}")
            return None
        frames = [
            _align_to_schema(session.read.parquet(os.path.join(folder_path, file_name)), schema)
            for file_name in files
        ]
        print("Appending parquet files")
        # Accumulate all months. The original loop rebound the loop variable and
        # unioned each frame with itself, losing every file but the last.
        combined = reduce(lambda left, right: left.unionByName(right), frames)
        top_rows = int(combined.count() * top_fraction)
        return combined.orderBy(col("trip_distance").desc()).limit(top_rows)
    except Exception as e:
        # Spark surfaces failures through several exception types (AnalysisException,
        # Py4J errors, ...); keep the original "print and return None" contract.
        print(f"Error processing parquet data: {str(e)}")
        return None
def data_quality_total_amount(sdf):
    """Check and correct data-quality issues in the ``total_amount`` column.

    Reports nulls and negative values in ``total_amount``. For rows whose total
    is negative, the sign of the total and of each monetary component column
    present in *sdf* is flipped, so the components keep matching the total.
    This assumes total_amount is the sum of those components (TODO confirm
    against the TLC data dictionary). Also warns when ``total_amount`` is not a
    DoubleType column.

    Args:
        sdf: Spark DataFrame containing a ``total_amount`` column, or None.

    Returns:
        ``(null_check, negative_check, corrected_sdf)``:
        - null_check: the rows with a null total.
        - negative_check: the rows with a negative total, as they were BEFORE
          correction.
        - corrected_sdf: the corrected DataFrame.
        Returns ``(None, None, None)`` if *sdf* is None or the check fails; the
        error is printed, not raised.
    """
    if sdf is None:
        print("Error in data quality check: no DataFrame to check")
        return None, None, None
    # Monetary columns whose sign is flipped together with total_amount.
    candidate_columns = (
        "fare_amount",
        "extra",
        "mta_tax",
        "tip_amount",
        "tolls_amount",
        "improvement_surcharge",
        "congestion_surcharge",
        "airport_fee",
        "total_amount",
    )
    try:
        total = col("total_amount")
        null_check = sdf.filter(total.isNull())
        # Derived from the input frame, so it keeps the pre-correction rows
        # even after `sdf` is rebound below.
        negative_check = sdf.filter(total < 0)
        if not null_check.isEmpty():
            print("Null values found in the 'total_amount' column.")
        if not negative_check.isEmpty():
            print("Negative 'total_amount' values found. Correcting...")
            # Resolve names case-insensitively and ignore columns this frame lacks.
            present = {name.lower(): name for name in sdf.columns}
            to_flip = {present[name] for name in candidate_columns if name in present}
            is_negative = total < 0
            # One select evaluates every expression against the ORIGINAL
            # total_amount. Chained withColumn calls would see the already-flipped
            # total for any column processed after it.
            sdf = sdf.select([
                when(is_negative, col(name) * -1).otherwise(col(name)).alias(name)
                if name in to_flip else col(name)
                for name in sdf.columns
            ])
        if sdf.schema["total_amount"].dataType != DoubleType():
            print("The data type of 'total_amount' column is not DoubleType.")
        return null_check, negative_check, sdf
    except Exception as e:
        print(f"Error in data quality check: {str(e)}")
        return None, None, None
def parser(argv=None):
    """Parse the script's command-line arguments.

    Args:
        argv: list of argument strings to parse; None (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``start_date``, ``end_date`` and
        ``download_path`` (default "raw_data").
    """
    # Named arg_parser so it does not shadow this function's own name.
    arg_parser = argparse.ArgumentParser(description="Download and process Yellow Taxi trip data.")
    arg_parser.add_argument("start_date", help="Start date in yyyy-MM format")
    arg_parser.add_argument("end_date", help="End date in yyyy-MM format")
    arg_parser.add_argument("--download_path", default="raw_data", help="Download path for Parquet files")
    return arg_parser.parse_args(argv)
if __name__ == '__main__':
    # Parse first so `--help` or a bad invocation does not pay for Spark start-up.
    args = parser()
    # Module-level name `spark`: kept global because other code in this file
    # may read it.
    spark = create_spark_session("TripRecordData")
    try:
        download_parquet_files(args.start_date, args.end_date, args.download_path)
        processed_data = process_parquet_data(args.download_path)
        null_check, negative_check, final_sdf = data_quality_total_amount(processed_data)
        if final_sdf is None:
            # The helpers above already printed why; exit non-zero instead of
            # crashing with AttributeError on None.write.
            raise SystemExit("No processed data to write; aborting.")
        final_sdf.write.mode("overwrite").parquet('processed_data')
        print("Writing processed files")
        if null_check is not None:
            null_check.write.mode("overwrite").parquet('processed_data/null_data')
        if negative_check is not None:
            negative_check.write.mode("overwrite").parquet('processed_data/negative_amount')
    finally:
        # Always release the Spark session, including on errors and SystemExit.
        spark.stop()