diff --git a/extensions/googletransit/setup_extension.py b/extensions/googletransit/setup_extension.py index c3ed3fe6..6c0bef61 100644 --- a/extensions/googletransit/setup_extension.py +++ b/extensions/googletransit/setup_extension.py @@ -21,6 +21,7 @@ from . import fareattribute from . import route from . import stop +from . import transfer def GetGtfsFactory(factory = None): if not factory: @@ -38,4 +39,7 @@ def GetGtfsFactory(factory = None): # Stop class extension factory.UpdateClass('Stop', stop.Stop) + # Transfer class extension + factory.UpdateClass('Transfer', transfer.Transfer) + return factory diff --git a/extensions/googletransit/transfer.py b/extensions/googletransit/transfer.py new file mode 100644 index 00000000..2a2ad2d0 --- /dev/null +++ b/extensions/googletransit/transfer.py @@ -0,0 +1,60 @@ +#!/usr/bin/python2.5 + +# Copyright (C) 2011 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import transitfeed +import transitfeed.util as util + +class Transfer(transitfeed.Transfer): + """Extension of transitfeed.Transfer: + - Adding fields 'from_route_id', to_route_id', 'from_trip_id', 'to_trip_id' + See propposal at + https://developers.google.com/transit/gtfs/reference/gtfs-extensions#TripToTripTransfers + """ + + _FIELD_NAMES = transitfeed.Transfer._FIELD_NAMES + [ 'from_route_id', 'to_route_id', 'from_trip_id', 'to_trip_id' ] + _ID_COLUMNS = transitfeed.Transfer._ID_COLUMNS + [ 'from_route_id', 'to_route_id', 'from_trip_id', 'to_trip_id' ] + + def ValidateFromRouteIdIsValid(self, problems): + if not util.IsEmpty(self.from_route_id) and self.from_route_id not in self._schedule.routes.keys(): + problems.InvalidValue('from_route_id', self.from_route_id) + return False + return True + + def ValidateToRouteIdIsValid(self, problems): + if not util.IsEmpty(self.to_route_id) and self.to_route_id not in self._schedule.routes.keys(): + problems.InvalidValue('to_route_id', self.to_route_id) + return False + return True + + def ValidateFromTripIdIsValid(self, problems): + if not util.IsEmpty(self.from_trip_id) and self.from_trip_id not in self._schedule.trips.keys(): + problems.InvalidValue('from_trip_id', self.from_trip_id) + return False + return True + + def ValidateToTripIdIsValid(self, problems): + if not util.IsEmpty(self.to_trip_id) and self.to_trip_id not in self._schedule.trips.keys(): + problems.InvalidValue('to_trip_id', self.to_trip_id) + return False + return True + + def ValidateAfterAdd(self, problems): + result = super(Transfer, self).ValidateAfterAdd(problems) + result = self.ValidateFromRouteIdIsValid(problems) and result + result = self.ValidateToRouteIdIsValid(problems) and result + result = self.ValidateFromTripIdIsValid(problems) and result + result = self.ValidateToTripIdIsValid(problems) and result + return result \ No newline at end of file diff --git a/tests/testgoogletransitextension.py b/tests/testgoogletransitextension.py index 9087b910..57133423 100644 --- a/tests/testgoogletransitextension.py +++ b/tests/testgoogletransitextension.py @@ -357,3 +357,53 @@ def testNotValidAgencyLang(self): self.assertTrue(e_msg.find('not valid') != -1, '%s should not be valid, is: %s' % (e.value, e_msg)) self.accumulator.AssertNoMoreExceptions() + +class TransferTestCase(ExtensionMemoryZipTestCase): + gtfs_factory = extensions.googletransit.GetGtfsFactory() + + def testNoErrorsTransferRouteIdTripColumnsNotPresent(self): + self.SetArchiveContents( + "stops.txt", + "stop_id,stop_name,stop_lat,stop_lon\n" + "BEATTY_AIRPORT,Airport,36.868446,-116.784582\n" + "BULLFROG,Bullfrog,36.88108,-116.81797\n" + "STAGECOACH,Stagecoach Hotel,36.915682,-116.751677\n" + "STAGECOACH_BAR,Stagecoach Bar,36.916,-116.752\n") + self.SetArchiveContents("transfers.txt", + "from_stop_id,to_stop_id,transfer_type,min_transfer_time\n" + "STAGECOACH,STAGECOACH_BAR,2,180\n") + self.SetArchiveContents( + "stop_times.txt", + "trip_id,arrival_time,departure_time,stop_id,stop_sequence\n" + "AB1,10:00:00,10:00:00,BEATTY_AIRPORT,1\n" + "AB1,10:20:00,10:20:00,BULLFROG,2\n" + "AB1,10:25:00,10:25:00,STAGECOACH,3\n" + "AB1,10:26:00,10:26:00,STAGECOACH_BAR,4\n") + + self.MakeLoaderAndLoad(self.problems, gtfs_factory=self.gtfs_factory) + self.accumulator.AssertNoMoreExceptions() + + def testTransferInvalidTripId(self): + self.SetArchiveContents( + "stops.txt", + "stop_id,stop_name,stop_lat,stop_lon\n" + "BEATTY_AIRPORT,Airport,36.868446,-116.784582\n" + "BULLFROG,Bullfrog,36.88108,-116.81797\n" + "STAGECOACH,Stagecoach Hotel,36.915682,-116.751677\n" + "STAGECOACH_BAR,Stagecoach Bar,36.916,-116.752\n") + self.SetArchiveContents("transfers.txt", + "from_stop_id,to_stop_id,transfer_type,min_transfer_time,from_route_id," + "to_route_id,from_trip_id,to_trip_id\n" + "STAGECOACH,STAGECOACH_BAR,2,180,,,,\n" + "STAGECOACH,STAGECOACH_BAR,2,180,,,AB1,INVALID_TRIP_ID\n") + self.SetArchiveContents( + "stop_times.txt", + "trip_id,arrival_time,departure_time,stop_id,stop_sequence\n" + "AB1,10:00:00,10:00:00,BEATTY_AIRPORT,1\n" + "AB1,10:20:00,10:20:00,BULLFROG,2\n" + "AB1,10:25:00,10:25:00,STAGECOACH,3\n" + "AB1,10:26:00,10:26:00,STAGECOACH_BAR,4\n") + + self.MakeLoaderAndLoad(self.problems, gtfs_factory=self.gtfs_factory) + self.accumulator.PopInvalidValue("to_trip_id") + self.accumulator.AssertNoMoreExceptions() diff --git a/transitfeed/gtfsfactory.py b/transitfeed/gtfsfactory.py index 76528b0f..e294fb77 100644 --- a/transitfeed/gtfsfactory.py +++ b/transitfeed/gtfsfactory.py @@ -51,8 +51,8 @@ def __init__(self): 'Stop': Stop, 'StopTime': StopTime, 'Route': Route, - 'Transfer': Transfer, 'Trip': Trip, + 'Transfer': Transfer, 'Schedule': Schedule, 'Loader': Loader } @@ -91,12 +91,11 @@ def __init__(self): 'routes.txt': { 'required': True, 'loading_order': 20, 'classes': ['Route']}, - 'transfers.txt': { 'required': False, 'loading_order': 30, - 'classes': ['Transfer']}, - - 'trips.txt': { 'required': True, 'loading_order': 40, + 'trips.txt': { 'required': True, 'loading_order': 30, 'classes': ['Trip']}, + 'transfers.txt': { 'required': False, 'loading_order': 40, + 'classes': ['Transfer']}, } def __getattr__(self, name):